Refactor inpaint/ip2p code: move concat conditioning into a concat_cond() method that IP2P overrides instead of extra_conds().

This commit is contained in:
comfyanonymous 2024-11-19 02:34:35 -05:00
parent 9cc90ee3eb
commit b699a15062

View File

@@ -153,8 +153,7 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return None return None
def extra_conds(self, **kwargs): def concat_cond(self, **kwargs):
out = {}
if len(self.concat_keys) > 0: if len(self.concat_keys) > 0:
cond_concat = [] cond_concat = []
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
@@ -193,7 +192,14 @@ class BaseModel(torch.nn.Module):
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise)) cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1) data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data) return data
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
if concat_cond is not None:
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
adm = self.encode_adm(**kwargs) adm = self.encode_adm(**kwargs)
if adm is not None: if adm is not None:
@@ -523,9 +529,7 @@ class SD_X4Upscaler(BaseModel):
return out return out
class IP2P: class IP2P:
def extra_conds(self, **kwargs): def concat_cond(self, **kwargs):
out = {}
image = kwargs.get("concat_latent_image", None) image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
@@ -537,18 +541,15 @@ class IP2P:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0]) image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
return out
class SD15_instructpix2pix(IP2P, BaseModel): class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL): class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)