diff --git a/comfy/model_base.py b/comfy/model_base.py index 2fb4b1453..eec70d5de 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -59,6 +59,7 @@ class ModelType(Enum): FLOW = 6 V_PREDICTION_CONTINUOUS = 7 FLUX = 8 + IMG_TO_IMG = 9 from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV @@ -89,6 +90,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux + elif model_type == ModelType.IMG_TO_IMG: + c = comfy.model_sampling.IMG_TO_IMG class ModelSampling(s, c): pass @@ -613,7 +616,7 @@ class Lotus(BaseModel): out['y'] = comfy.conds.CONDRegular(task_emb) return out - def __init__(self, model_config, model_type=ModelType.EPS, device=None): + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None): super().__init__(model_config, model_type, device=device) class StableCascade_C(BaseModel): diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index ff27b09a8..b79af1e92 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -69,6 +69,15 @@ class CONST: sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) return latent / (1.0 - sigma) +class X0(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + +class IMG_TO_IMG(X0): + def calculate_input(self, sigma, noise): + return noise + + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): super().__init__() diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 2b805c1ee..71a652ffa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -20,14 +20,6 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class X0(comfy.model_sampling.EPS): - def calculate_denoised(self, sigma, model_output, model_input): - return model_output - -class Lotus(X0): - def calculate_input(self, sigma, noise): - return noise - class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -60,7 +52,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0", "lotus"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -81,9 +73,9 @@ class ModelSamplingDiscrete: sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled elif sampling == "x0": - sampling_type = X0 - elif sampling == "lotus": - sampling_type = Lotus + sampling_type = comfy.model_sampling.X0 + elif sampling == "img_to_img": + sampling_type = comfy.model_sampling.IMG_TO_IMG class ModelSamplingAdvanced(sampling_base, sampling_type): pass