mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
More flexibility with text encoder return values.
Text encoders can now return other values to the CONDITIONING than the cond and pooled output.
This commit is contained in:
parent
e44fa5667f
commit
391c1046cf
12
comfy/sd.py
12
comfy/sd.py
@ -130,7 +130,7 @@ class CLIP:
|
||||
def tokenize(self, text, return_word_ids=False):
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False):
|
||||
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
if self.layer_idx is not None:
|
||||
@ -140,7 +140,15 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
out = {"cond": cond, "pooled_output": pooled}
|
||||
if len(o) > 2:
|
||||
for k in o[2]:
|
||||
out[k] = o[2][k]
|
||||
return out
|
||||
|
||||
if return_pooled:
|
||||
return cond, pooled
|
||||
return cond
|
||||
|
@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
|
||||
|
||||
if len(o) > 2:
|
||||
extra = {}
|
||||
for k in o[2]:
|
||||
v = o[2][k]
|
||||
if k == "attention_mask":
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
||||
extra[k] = v
|
||||
|
||||
r = r + (extra,)
|
||||
return r
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
|
||||
extra = {}
|
||||
if self.return_attention_masks:
|
||||
return z, pooled_output, attention_mask
|
||||
extra["attention_mask"] = attention_mask
|
||||
|
||||
if len(extra) > 0:
|
||||
return z, pooled_output, extra
|
||||
|
||||
return z, pooled_output
|
||||
|
||||
@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out, pooled
|
||||
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
5
nodes.py
5
nodes.py
@ -55,8 +55,9 @@ class CLIPTextEncode:
|
||||
|
||||
def encode(self, clip, text):
|
||||
tokens = clip.tokenize(text)
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled}]], )
|
||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
||||
cond = output.pop("cond")
|
||||
return ([[cond, output]], )
|
||||
|
||||
class ConditioningCombine:
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user