mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Support SDXL inpaint models.
This commit is contained in:
parent
c335fdf200
commit
7931ff0fd9
@ -111,6 +111,9 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||
|
||||
def set_inpaint(self):
|
||||
self.concat_keys = ("mask", "masked_image")
|
||||
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||
adm_inputs = []
|
||||
weights = []
|
||||
@ -148,12 +151,6 @@ class SD21UNCLIP(BaseModel):
|
||||
else:
|
||||
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
|
||||
|
||||
|
||||
class SDInpaint(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("mask", "masked_image")
|
||||
|
||||
def sdxl_pooled(args, noise_augmentor):
|
||||
if "unclip_conditioning" in args:
|
||||
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
|
||||
|
@ -183,8 +183,12 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
||||
|
||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet]
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
@ -355,13 +355,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||
model_config.unet_config = unet_config
|
||||
|
||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
||||
model = model_base.SDInpaint(model_config, model_type=model_type)
|
||||
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
||||
else:
|
||||
model = model_base.BaseModel(model_config, model_type=model_type)
|
||||
|
||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
||||
model.set_inpaint()
|
||||
|
||||
if fp16:
|
||||
model = model.half()
|
||||
|
||||
|
@ -153,7 +153,10 @@ class SDXL(supported_models_base.BASE):
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
if self.inpaint_model():
|
||||
out.set_inpaint()
|
||||
return out
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
keys_to_replace = {}
|
||||
|
@ -57,12 +57,13 @@ class BASE:
|
||||
self.unet_config[x] = self.unet_extra_config[x]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.inpaint_model():
|
||||
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
elif self.noise_aug_config is not None:
|
||||
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
if self.noise_aug_config is not None:
|
||||
out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
else:
|
||||
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
out = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
if self.inpaint_model():
|
||||
out.set_inpaint()
|
||||
return out
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
Loading…
Reference in New Issue
Block a user