From 2a134bfab9788b6a0a70aea3172d8e3fc904b414 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 27 Oct 2023 22:13:55 -0400 Subject: [PATCH] Fix checkpoint loader with config. --- comfy/sd.py | 6 ++++-- comfy/sd1_clip.py | 4 ++-- comfy/sd2_clip.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index aea55bbdf..4a2823c9d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -388,11 +388,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_h elif clip_config["target"].endswith("FrozenCLIPEmbedder"): clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_l load_clip_weights(w, state_dict) return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 5368a45df..fdaa1e6c7 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -472,11 +472,11 @@ class SD1Tokenizer: class SD1ClipModel(torch.nn.Module): - def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel): + def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs): super().__init__() self.clip_name = clip_name self.clip = "clip_{}".format(self.clip_name) - setattr(self, self.clip, clip_model(device=device, dtype=dtype)) + setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) def clip_layer(self, layer_idx): getattr(self, self.clip).clip_layer(layer_idx) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 9df868b76..ebabf7ccd 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -21,5 +21,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer) class SD2ClipModel(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None): - super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel) + def __init__(self, device="cpu", dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)