diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index b84a38490..9978b6c35 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -71,6 +71,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.empty_tokens = [[49406] + [49407] * 76]
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+        self.enable_attention_masks = False
 
         self.layer_norm_hidden_state = True
         if layer == "hidden":
@@ -147,7 +148,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             precision_scope = lambda a, b: contextlib.nullcontext(a)
 
         with precision_scope(model_management.get_autocast_device(device), torch.float32):
-            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            attention_mask = None
+            if self.enable_attention_masks:
+                attention_mask = torch.zeros_like(tokens)
+                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+                for x in range(attention_mask.shape[0]):
+                    for y in range(attention_mask.shape[1]):
+                        attention_mask[x, y] = 1
+                        if tokens[x, y] == max_token:
+                            break
+
+            outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
             self.transformer.set_input_embeddings(backup_embeds)
 
             if self.layer == "last":
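
For reference, a minimal standalone sketch (not part of the patch) of what the added masking loop computes: each sequence is attended up to and including its first end/pad token, and everything after it is masked out. `max_token` here plays the role of `weight.shape[0] - 1` from the diff; for the standard SD1 CLIP vocabulary that is 49407, the end-of-text id also used for padding in `empty_tokens`. The example token values are made up.

```python
import torch

# Toy (batch, seq_len) token ids; 49407 acts as the end/pad token.
tokens = torch.tensor([[49406, 320, 1125, 49407, 49407],
                       [49406, 49407, 49407, 49407, 49407]])
max_token = 49407

attention_mask = torch.zeros_like(tokens)
for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1          # attend to this position
        if tokens[x, y] == max_token:     # stop after the first end/pad token
            break

print(attention_mask)
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 0, 0, 0]])
```

Because the mask defaults to `None` and `enable_attention_masks` is initialized to `False`, the transformer call behaves exactly as before unless masking is explicitly switched on.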