From b699a15062e01230663964961b23c7ddf7c6c826 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 19 Nov 2024 02:34:35 -0500 Subject: [PATCH] Refactor inpaint/ip2p code. --- comfy/model_base.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index f2833168..7e92ca10 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -153,8 +153,7 @@ class BaseModel(torch.nn.Module): def encode_adm(self, **kwargs): return None - def extra_conds(self, **kwargs): - out = {} + def concat_cond(self, **kwargs): if len(self.concat_keys) > 0: cond_concat = [] denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) @@ -193,7 +192,14 @@ class BaseModel(torch.nn.Module): elif ck == "masked_image": cond_concat.append(self.blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) - out['c_concat'] = comfy.conds.CONDNoiseShape(data) + return data + return None + + def extra_conds(self, **kwargs): + out = {} + concat_cond = self.concat_cond(**kwargs) + if concat_cond is not None: + out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond) adm = self.encode_adm(**kwargs) if adm is not None: @@ -523,9 +529,7 @@ class SD_X4Upscaler(BaseModel): return out class IP2P: - def extra_conds(self, **kwargs): - out = {} - + def concat_cond(self, **kwargs): image = kwargs.get("concat_latent_image", None) noise = kwargs.get("noise", None) device = kwargs["device"] @@ -537,18 +541,15 @@ class IP2P: image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.resize_to_batch_size(image, noise.shape[0]) + return self.process_ip2p_image_in(image) - out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image)) - adm = self.encode_adm(**kwargs) - if adm is not None: - out['y'] = comfy.conds.CONDRegular(adm) - return out class SD15_instructpix2pix(IP2P, BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device) self.process_ip2p_image_in = lambda image: image + class SDXL_instructpix2pix(IP2P, SDXL): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device)