diff --git a/comfy/model_base.py b/comfy/model_base.py
index 677a23de..ca154dba 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -111,6 +111,9 @@ class BaseModel(torch.nn.Module):
         return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
 
+    def set_inpaint(self):
+        self.concat_keys = ("mask", "masked_image")
+
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
     adm_inputs = []
     weights = []
@@ -148,12 +151,6 @@ class SD21UNCLIP(BaseModel):
         else:
             return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
 
-
-class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
-        super().__init__(model_config, model_type, device=device)
-        self.concat_keys = ("mask", "masked_image")
-
 def sdxl_pooled(args, noise_augmentor):
     if "unclip_conditioning" in args:
         return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 0edc4f18..372d5a2d 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -183,8 +183,12 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
                        'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
                        'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
 
+    SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
+                              'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+                              'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet]
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
 
     for unet_config in supported_models:
         matches = True
diff --git a/comfy/sd.py b/comfy/sd.py
index e98dabe8..8be0bcbc 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -355,13 +355,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
     model_config.unet_config = unet_config
 
-    if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model = model_base.SDInpaint(model_config, model_type=model_type)
-    elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
+    if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
         model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
     else:
         model = model_base.BaseModel(model_config, model_type=model_type)
 
+    if config['model']["target"].endswith("LatentInpaintDiffusion"):
+        model.set_inpaint()
+
     if fp16:
         model = model.half()
 
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 95fc8f3f..0b3e4bcb 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -153,7 +153,10 @@ class SDXL(supported_models_base.BASE):
             return model_base.ModelType.EPS
 
     def get_model(self, state_dict, prefix="", device=None):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
+        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
+        if self.inpaint_model():
+            out.set_inpaint()
+        return out
 
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index c9cd54d0..395a90ab 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -57,12 +57,13 @@ class BASE:
             self.unet_config[x] = self.unet_extra_config[x]
 
     def get_model(self, state_dict, prefix="", device=None):
-        if self.inpaint_model():
-            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
-        elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
+        if self.noise_aug_config is not None:
+            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
         else:
-            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
+            out = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
+        if self.inpaint_model():
+            out.set_inpaint()
+        return out
 
     def process_clip_state_dict(self, state_dict):
         return state_dict
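Note (not part of the diff above): the refactor drops the dedicated SDInpaint subclass and instead toggles inpaint behaviour on an already-constructed model via set_inpaint(), so inpainting can be combined with any model class (BaseModel, SD21UNCLIP, SDXL). The new SDXL_diffusers_inpaint detection entry uses 'in_channels': 9, consistent with an inpaint UNet that receives the 4-channel noisy latent plus a 1-channel mask and the 4-channel masked-image latent. Below is a minimal standalone sketch of the construction pattern; BaseModel, SDXL, and get_model here are simplified stand-ins rather than ComfyUI's real constructors or signatures, and only set_inpaint()/concat_keys mirror the actual change.

# Sketch only: simplified stand-ins for the pattern introduced by the diff.
class BaseModel:
    def __init__(self):
        self.concat_keys = ()

    def set_inpaint(self):
        # Inpaint models concatenate the mask and masked image to the latent input.
        self.concat_keys = ("mask", "masked_image")

class SDXL(BaseModel):
    pass

def get_model(is_sdxl, is_inpaint):
    # Before: inpainting required a dedicated SDInpaint(BaseModel) subclass,
    # which an SDXL or unCLIP model could not also be.
    # After: build whichever class fits the checkpoint, then flip it into inpaint mode.
    model = SDXL() if is_sdxl else BaseModel()
    if is_inpaint:
        model.set_inpaint()
    return model

print(get_model(True, True).concat_keys)    # ('mask', 'masked_image')
print(get_model(False, False).concat_keys)  # ()

This is why both supported_models_base.BASE.get_model() and supported_models.SDXL.get_model() in the diff construct the model first and only then call out.set_inpaint() when self.inpaint_model() is true.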