From 391c1046cff8a3877a2ba343057579ab4278c5b1 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 10 Jul 2024 20:06:50 -0400
Subject: [PATCH] More flexibility with text encoder return values.

Text encoders can now return other values to the CONDITIONING than the
cond and pooled output.
---
 comfy/sd.py       | 12 ++++++++++--
 comfy/sd1_clip.py | 21 +++++++++++++++++----
 nodes.py          |  5 +++--
 3 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index b39230bd..25454f83 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -130,7 +130,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)

-    def encode_from_tokens(self, tokens, return_pooled=False):
+    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
         self.cond_stage_model.reset_clip_options()

         if self.layer_idx is not None:
@@ -140,7 +140,15 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})

         self.load_model()
-        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
+        o = self.cond_stage_model.encode_token_weights(tokens)
+        cond, pooled = o[:2]
+        if return_dict:
+            out = {"cond": cond, "pooled_output": pooled}
+            if len(o) > 2:
+                for k in o[2]:
+                    out[k] = o[2][k]
+            return out
+
         if return_pooled:
             return cond, pooled
         return cond
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4da2b46f..565ad69d 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
             r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
         else:
             r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
-        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+
+        if len(o) > 2:
+            extra = {}
+            for k in o[2]:
+                v = o[2][k]
+                if k == "attention_mask":
+                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+                extra[k] = v
+
+            r = r + (extra,)
         return r

 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
@@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()

+        extra = {}
         if self.return_attention_masks:
-            return z, pooled_output, attention_mask
+            extra["attention_mask"] = attention_mask
+
+        if len(extra) > 0:
+            return z, pooled_output, extra

         return z, pooled_output

@@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):

     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
-        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
-        return out, pooled
+        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out

     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)
diff --git a/nodes.py b/nodes.py
index 8d8b153f..5778f060 100644
--- a/nodes.py
+++ b/nodes.py
@@ -55,8 +55,9 @@ class CLIPTextEncode:

     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )

 class ConditioningCombine:
     @classmethod
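
Below is a minimal usage sketch, not part of the patch, showing how a caller would consume the new dict-style return. It mirrors the updated CLIPTextEncode.encode in nodes.py and assumes a loaded comfy.sd.CLIP instance named `clip`; the prompt string is purely illustrative.

    # Sketch: encode text and keep any extra encoder outputs (for example an
    # attention_mask when the text encoder enables return_attention_masks).
    tokens = clip.tokenize("a photo of a cat")  # hypothetical prompt
    output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
    cond = output.pop("cond")
    # Whatever remains in `output` ("pooled_output" plus any extra keys the
    # text encoder returned) is attached to the CONDITIONING entry.
    conditioning = [[cond, output]]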