More flexibility with text encoder return values.

Text encoders can now return values other than the cond and pooled output to the CONDITIONING.
comfyanonymous 2024-07-10 20:06:50 -04:00
parent e44fa5667f
commit 391c1046cf
3 changed files with 30 additions and 8 deletions
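
In practice (a minimal sketch assuming a loaded `clip` object; "attention_mask" is just one example of an extra key an encoder might emit), callers can now request a dict instead of a tuple:

    tokens = clip.tokenize("a photo of a cat")
    output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
    cond = output.pop("cond")
    # output now holds "pooled_output" plus any extra keys the text encoder
    # returned, e.g. "attention_mask" when the model exposes one.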

comfy/sd.py

@@ -130,7 +130,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
-    def encode_from_tokens(self, tokens, return_pooled=False):
+    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
         self.cond_stage_model.reset_clip_options()
 
         if self.layer_idx is not None:
@@ -140,7 +140,15 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
-        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
+        o = self.cond_stage_model.encode_token_weights(tokens)
+        cond, pooled = o[:2]
+        if return_dict:
+            out = {"cond": cond, "pooled_output": pooled}
+            if len(o) > 2:
+                for k in o[2]:
+                    out[k] = o[2][k]
+            return out
+
         if return_pooled:
             return cond, pooled
         return cond
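
The unpacking above accepts both return shapes of encode_token_weights: the old (cond, pooled) pair and the new optional 3-tuple whose last element is a dict of extra values. A minimal sketch (placeholder tensors, with "attention_mask" as an assumed example key):

    import torch

    cond, pooled = torch.zeros(1, 77, 768), torch.zeros(1, 768)   # placeholders
    o = (cond, pooled, {"attention_mask": torch.ones(1, 77)})     # new optional form
    cond, pooled = o[:2]
    extra = o[2] if len(o) > 2 else {}                            # {} for the old 2-tuple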

comfy/sd1_clip.py

@@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
             r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
         else:
             r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
-        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+
+        if len(o) > 2:
+            extra = {}
+            for k in o[2]:
+                v = o[2][k]
+                if k == "attention_mask":
+                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+                extra[k] = v
+
+            r = r + (extra,)
         return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
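
Why the attention_mask gets reshaped above: cond is the concatenation of `sections` fixed-size token chunks along the token axis, so the per-chunk masks are flattened into a single row to line up with it. A rough, self-contained illustration (the chunk size of 77 and batch of 3 are assumed values):

    import torch

    sections = 2
    mask = torch.ones(3, 77, dtype=torch.long)         # one row of mask per chunk
    flat = mask[:sections].flatten().unsqueeze(dim=0)  # keep the first `sections` chunks
    print(flat.shape)                                  # torch.Size([1, 154])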
@@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             elif outputs[2] is not None:
                 pooled_output = outputs[2].float()
 
+        extra = {}
         if self.return_attention_masks:
-            return z, pooled_output, attention_mask
+            extra["attention_mask"] = attention_mask
+
+        if len(extra) > 0:
+            return z, pooled_output, extra
 
         return z, pooled_output
@@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
-        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
-        return out, pooled
+        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out
 
     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)

nodes.py

@@ -55,8 +55,9 @@ class CLIPTextEncode:
 
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )
 
 class ConditioningCombine:
     @classmethod
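
Net effect on the CONDITIONING output: each entry is still a [cond, dict] pair, but the dict can now carry extra encoder outputs alongside pooled_output. An illustrative sketch (all tensor shapes are placeholders):

    import torch

    cond = torch.zeros(1, 154, 768)
    pooled = torch.zeros(1, 768)
    mask = torch.ones(1, 154, dtype=torch.long)

    old_entry = [[cond, {"pooled_output": pooled}]]                           # before
    new_entry = [[cond, {"pooled_output": pooled, "attention_mask": mask}]]   # after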