mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support returning text encoder attention masks.
This commit is contained in:
parent
90389b3b8a
commit
e44fa5667f
@ -38,7 +38,9 @@ class ClipTokenWeightEncoder:
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
o = self.encode(to_encode)
|
||||
out, pooled = o[:2]
|
||||
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
@ -57,8 +59,11 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
|
||||
return r
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
@ -70,7 +75,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True): # clip-vit-base-patch32
|
||||
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
@ -96,6 +101,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
self.return_attention_masks = return_attention_masks
|
||||
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
@ -169,7 +175,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks or self.zero_out_masked:
|
||||
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
end_token = self.special_tokens.get("end", -1)
|
||||
for x in range(attention_mask.shape[0]):
|
||||
@ -200,6 +206,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
|
||||
if self.return_attention_masks:
|
||||
return z, pooled_output, attention_mask
|
||||
|
||||
return z, pooled_output
|
||||
|
||||
def encode(self, tokens):
|
||||
|
Loading…
Reference in New Issue
Block a user