mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Fix checkpoint loader with config.
This commit is contained in:
parent
e60ca6929a
commit
2a134bfab9
@ -388,11 +388,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
||||||
clip_target.clip = sd2_clip.SD2ClipModel
|
clip_target.clip = sd2_clip.SD2ClipModel
|
||||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||||
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
|
w.cond_stage_model = clip.cond_stage_model.clip_h
|
||||||
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
w.cond_stage_model = clip.cond_stage_model
|
w.cond_stage_model = clip.cond_stage_model.clip_l
|
||||||
load_clip_weights(w, state_dict)
|
load_clip_weights(w, state_dict)
|
||||||
|
|
||||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||||
|
@ -472,11 +472,11 @@ class SD1Tokenizer:
|
|||||||
|
|
||||||
|
|
||||||
class SD1ClipModel(torch.nn.Module):
|
class SD1ClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel):
|
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def clip_layer(self, layer_idx):
|
||||||
getattr(self, self.clip).clip_layer(layer_idx)
|
getattr(self, self.clip).clip_layer(layer_idx)
|
||||||
|
@ -21,5 +21,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||||
|
|
||||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel)
|
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user