From e813abbb2c27448735d02af819e79f1a036b7212 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 15 Sep 2024 07:59:18 -0400 Subject: [PATCH] Long CLIP L support for SDXL, SD3 and Flux. Use the *CLIPLoader nodes. --- comfy/sd.py | 12 +++++------- comfy/sd1_clip.py | 2 ++ comfy/sdxl_clip.py | 9 ++++++--- comfy/text_encoders/flux.py | 6 ++++-- comfy/text_encoders/long_clipl.py | 15 +++++++++++++-- comfy/text_encoders/sd3_clip.py | 9 ++++++--- 6 files changed, 36 insertions(+), 17 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 1e1a6594..07310b9d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -445,12 +445,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer else: w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None) - if w is not None and w.shape[0] == 248: - clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel - clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer - else: - clip_target.clip = sd1_clip.SD1ClipModel - clip_target.tokenizer = sd1_clip.SD1Tokenizer + clip_target.clip = sd1_clip.SD1ClipModel + clip_target.tokenizer = sd1_clip.SD1Tokenizer elif len(clip_data) == 2: if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False) @@ -475,10 +471,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer parameters = 0 + tokenizer_data = {} for c in clip_data: parameters += comfy.utils.calculate_parameters(c) + tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options) for c in clip_data: m, u = clip.load_sd(c) if len(m) > 0: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 676653f7..9f0d95b0 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -542,6 +542,7 @@ class SD1Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer): self.clip_name = clip_name self.clip = "clip_{}".format(self.clip_name) + tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer) setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)) def tokenize_with_weights(self, text:str, return_word_ids=False): @@ -570,6 +571,7 @@ class SD1ClipModel(torch.nn.Module): self.clip_name = clip_name self.clip = "clip_{}".format(self.clip_name) + clip_model = model_options.get("{}_class".format(self.clip), clip_model) setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs)) self.dtypes = set() diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index a0145caa..4d0a4e8e 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer): class SDXLTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) + clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) + self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False): @@ -40,7 +41,8 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__() - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) + clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) + self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.dtypes = set([dtype]) @@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module): token_weight_pairs_l = token_weight_pairs["l"] g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) - return torch.cat([l_out, g_out], dim=-1), g_pooled + cut_to = min(l_out.shape[1], g_out.shape[1]) + return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled def load_sd(self, sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 91d1f249..7c75fe64 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -18,7 +18,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer): class FluxTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) + clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) + self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False): @@ -38,7 +39,8 @@ class FluxClipModel(torch.nn.Module): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) - self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) + self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options) self.dtypes = set([dtype, dtype_t5]) diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index 4677fb3b..b81912cb 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer): super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) class LongClipModel_(sd1_clip.SDClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}): + def __init__(self, *args, **kwargs): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json") - super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options) + super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs) class LongClipTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer): class LongClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs) + +def model_options_long_clip(sd, tokenizer_data, model_options): + w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None) + if w is None: + w = sd.get("text_model.embeddings.position_embedding.weight", None) + if w is not None and w.shape[0] == 248: + tokenizer_data = tokenizer_data.copy() + model_options = model_options.copy() + tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_ + model_options["clip_l_class"] = LongClipModel_ + return tokenizer_data, model_options diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index e3832ac2..c54f2885 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -20,7 +20,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer): class SD3Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) + clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) + self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) @@ -42,7 +43,8 @@ class SD3ClipModel(torch.nn.Module): super().__init__() self.dtypes = set() if clip_l: - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) + clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) + self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) self.dtypes.add(dtype) else: self.clip_l = None @@ -95,7 +97,8 @@ class SD3ClipModel(torch.nn.Module): if self.clip_g is not None: g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) if lg_out is not None: - lg_out = torch.cat([lg_out, g_out], dim=-1) + cut_to = min(lg_out.shape[1], g_out.shape[1]) + lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1) else: lg_out = torch.nn.functional.pad(g_out, (768, 0)) else: