diff --git a/comfy/sd.py b/comfy/sd.py index a63a0d1de..e98dabe88 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -22,6 +22,7 @@ from . import sdxl_clip import comfy.model_patcher import comfy.lora import comfy.t2i_adapter.adapter +import comfy.supported_models_base def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -348,10 +349,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl class EmptyClass: pass - model_config = EmptyClass() - model_config.unet_config = unet_config + model_config = comfy.supported_models_base.BASE({}) + from . import latent_formats model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) + model_config.unet_config = unet_config if config['model']["target"].endswith("LatentInpaintDiffusion"): model = model_base.SDInpaint(model_config, model_type=model_type) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index c72838008..c9cd54d0e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -1,6 +1,7 @@ import torch from . import model_base from . import utils +from . import latent_formats def state_dict_key_replace(state_dict, keys_to_replace): @@ -34,6 +35,7 @@ class BASE: clip_vision_prefix = None noise_aug_config = None beta_schedule = "linear" + latent_format = latent_formats.LatentFormat @classmethod def matches(s, unet_config):