Refactor: Move some code to the comfy/text_encoders folder.

This commit is contained in:
comfyanonymous 2024-07-15 17:36:24 -04:00
parent 7914c47d5a
commit 1305fb294c
11 changed files with 20 additions and 20 deletions

View File

@ -19,8 +19,8 @@ from . import model_detection
from . import sd1_clip from . import sd1_clip
from . import sd2_clip from . import sd2_clip
from . import sdxl_clip from . import sdxl_clip
from . import sd3_clip import comfy.text_encoders.sd3_clip
from . import sa_t5 import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5 import comfy.text_encoders.aura_t5
import comfy.model_patcher import comfy.model_patcher
@ -414,27 +414,27 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype dtype_t5 = weight.dtype
if weight.shape[-1] == 4096: if weight.shape[-1] == 4096:
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif weight.shape[-1] == 2048: elif weight.shape[-1] == 2048:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]: elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
clip_target.clip = sa_t5.SAT5Model clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = sa_t5.SAT5Tokenizer clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else: else:
clip_target.clip = sd1_clip.SD1ClipModel clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2: elif len(clip_data) == 2:
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False) clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
clip_target.tokenizer = sd3_clip.SD3Tokenizer clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else: else:
clip_target.clip = sdxl_clip.SDXLClipModel clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3: elif len(clip_data) == 3:
clip_target.clip = sd3_clip.SD3ClipModel clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
clip_target.tokenizer = sd3_clip.SD3Tokenizer clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory) clip = CLIP(clip_target, embedding_directory=embedding_directory)
for c in clip_data: for c in clip_data:

View File

@ -5,8 +5,8 @@ from . import utils
from . import sd1_clip from . import sd1_clip
from . import sd2_clip from . import sd2_clip
from . import sdxl_clip from . import sdxl_clip
from . import sd3_clip import comfy.text_encoders.sd3_clip
from . import sa_t5 import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5 import comfy.text_encoders.aura_t5
from . import supported_models_base from . import supported_models_base
@ -524,7 +524,7 @@ class SD3(supported_models_base.BASE):
t5 = True t5 = True
dtype_t5 = state_dict[t5_key].dtype dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5)) return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
class StableAudio(supported_models_base.BASE): class StableAudio(supported_models_base.BASE):
unet_config = { unet_config = {
@ -555,7 +555,7 @@ class StableAudio(supported_models_base.BASE):
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model) return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model)
class AuraFlow(supported_models_base.BASE): class AuraFlow(supported_models_base.BASE):
unet_config = { unet_config = {

View File

@ -1,12 +1,12 @@
from comfy import sd1_clip from comfy import sd1_clip
from .llama_tokenizer import LLAMATokenizer from .llama_tokenizer import LLAMATokenizer
import comfy.t5 import comfy.text_encoders.t5
import os import os
class PT5XlModel(sd1_clip.SDClipModel): class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class PT5XlTokenizer(sd1_clip.SDTokenizer): class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None): def __init__(self, embedding_directory=None):

View File

@ -1,12 +1,12 @@
from comfy import sd1_clip from comfy import sd1_clip
from transformers import T5TokenizerFast from transformers import T5TokenizerFast
import comfy.t5 import comfy.text_encoders.t5
import os import os
class T5BaseModel(sd1_clip.SDClipModel): class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer): class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None): def __init__(self, embedding_directory=None):

View File

@ -1,7 +1,7 @@
from comfy import sd1_clip from comfy import sd1_clip
from comfy import sdxl_clip from comfy import sdxl_clip
from transformers import T5TokenizerFast from transformers import T5TokenizerFast
import comfy.t5 import comfy.text_encoders.t5
import torch import torch
import os import os
import comfy.model_management import comfy.model_management
@ -10,7 +10,7 @@ import logging
class T5XXLModel(sd1_clip.SDClipModel): class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None): def __init__(self, embedding_directory=None):