From ce649d61c05ab0e905155b4a010d4f0459a0db8d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 Jul 2024 23:48:17 -0400 Subject: [PATCH] Allow zeroing out of embeds with unused attention mask. --- comfy/sd1_clip.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 78e556b5..3b812b44 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -169,7 +169,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = torch.LongTensor(tokens).to(device) attention_mask = None - if self.enable_attention_masks: + if self.enable_attention_masks or self.zero_out_masked: attention_mask = torch.zeros_like(tokens) end_token = self.special_tokens.get("end", -1) for x in range(attention_mask.shape[0]): @@ -178,7 +178,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if tokens[x, y] == end_token: break - outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + attention_mask_model = None + if self.enable_attention_masks: + attention_mask_model = attention_mask + + outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": @@ -186,7 +190,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: z = outputs[1].float() - if self.zero_out_masked and attention_mask is not None: + if self.zero_out_masked: z *= attention_mask.unsqueeze(-1).float() pooled_output = None