Add model_options for text encoder.

This commit is contained in:
comfyanonymous 2024-08-17 10:15:13 -04:00
parent 858d51f91a
commit fca42836f2
9 changed files with 66 additions and 57 deletions

View File

@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP: class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0): def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
if no_init: if no_init:
return return
params = target.params.copy() params = target.params.copy()
@ -71,9 +71,14 @@ class CLIP:
load_device = model_management.text_encoder_device() load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device() offload_device = model_management.text_encoder_offload_device()
dtype = model_management.text_encoder_dtype(load_device) dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)) params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
params['model_options'] = model_options
self.cond_stage_model = clip(**(params)) self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes: for dt in self.cond_stage_model.dtypes:
@ -394,7 +399,7 @@ class CLIPType(Enum):
HUNYUAN_DIT = 5 HUNYUAN_DIT = 5
FLUX = 6 FLUX = 6
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@ -464,7 +469,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
for c in clip_data: for c in clip_data:
parameters += comfy.utils.calculate_parameters(c) parameters += comfy.utils.calculate_parameters(c)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters) clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
for c in clip_data: for c in clip_data:
m, u = clip.load_sd(c) m, u = clip.load_sd(c)
if len(m) > 0: if len(m) > 0:
@ -506,14 +511,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path) sd = comfy.utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
@ -563,7 +568,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd) clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0: if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd) parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters) clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True) m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0: if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))

View File

@ -84,7 +84,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32 return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
@ -94,7 +94,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f: with open(textmodel_json_config) as f:
config = json.load(f) config = json.load(f)
self.operations = comfy.ops.manual_cast operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations) self.transformer = model_class(config, dtype, device, self.operations)
self.num_layers = self.transformer.num_layers self.num_layers = self.transformer.num_layers
@ -553,7 +557,7 @@ class SD1Tokenizer:
return {} return {}
class SD1ClipModel(torch.nn.Module): class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
super().__init__() super().__init__()
if name is not None: if name is not None:
@ -563,7 +567,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set() self.dtypes = set()
if dtype is not None: if dtype is not None:

View File

@ -3,14 +3,14 @@ import torch
import os import os
class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, model_options=model_options)
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
@ -38,10 +38,10 @@ class SDXLTokenizer:
return {} return {}
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype]) self.dtypes = set([dtype])
def set_clip_options(self, options): def set_clip_options(self, options):
@ -66,8 +66,8 @@ class SDXLClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd) return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
@ -79,14 +79,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel): class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, model_options=model_options)
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)

View File

@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os import os
class PT5XlModel(sd1_clip.SDClipModel): class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class PT5XlTokenizer(sd1_clip.SDTokenizer): class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
class AuraT5Model(sd1_clip.SD1ClipModel): class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs) super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)

View File

@ -6,9 +6,9 @@ import torch
import os import os
class T5XXLModel(sd1_clip.SDClipModel): class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -35,11 +35,11 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module): class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False) self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5) self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5]) self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options): def set_clip_options(self, options):
@ -66,6 +66,6 @@ class FluxClipModel(torch.nn.Module):
def flux_clip(dtype_t5=None): def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel): class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype) super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_ return FluxClipModel_

View File

@ -7,9 +7,9 @@ import os
import torch import torch
class HyditBertModel(sd1_clip.SDClipModel): class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer): class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
class MT5XLModel(sd1_clip.SDClipModel): class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer): class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -50,10 +50,10 @@ class HyditTokenizer:
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]} return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
class HyditModel(torch.nn.Module): class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.hydit_clip = HyditBertModel(dtype=dtype) self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
self.mt5xl = MT5XLModel(dtype=dtype) self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
self.dtypes = set() self.dtypes = set()
if dtype is not None: if dtype is not None:

View File

@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os import os
class T5BaseModel(sd1_clip.SDClipModel): class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer): class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel): class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs) super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)

View File

@ -2,13 +2,13 @@ from comfy import sd1_clip
import os import os
class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, model_options=model_options)
class SD2ClipHTokenizer(sd1_clip.SDTokenizer): class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel): class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs) super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)

View File

@ -8,9 +8,9 @@ import comfy.model_management
import logging import logging
class T5XXLModel(sd1_clip.SDClipModel): class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -38,24 +38,24 @@ class SD3Tokenizer:
return {} return {}
class SD3ClipModel(torch.nn.Module): class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None): def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.dtypes = set() self.dtypes = set()
if clip_l: if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype) self.dtypes.add(dtype)
else: else:
self.clip_l = None self.clip_l = None
if clip_g: if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype) self.dtypes.add(dtype)
else: else:
self.clip_g = None self.clip_g = None
if t5: if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5) self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes.add(dtype_t5) self.dtypes.add(dtype_t5)
else: else:
self.t5xxl = None self.t5xxl = None
@ -132,6 +132,6 @@ class SD3ClipModel(torch.nn.Module):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None): def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel): class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_ return SD3ClipModel_