diff --git a/comfy/samplers.py b/comfy/samplers.py index 9c50d712b..b7ef6b963 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -277,6 +277,8 @@ class KSamplerX0Inpaint(torch.nn.Module): self.inner_model = model def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: + if "denoise_mask_function" in model_options: + denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask) latent_mask = 1. - denoise_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py new file mode 100644 index 000000000..48c95602f --- /dev/null +++ b/comfy_extras/nodes_differential_diffusion.py @@ -0,0 +1,97 @@ +# code adapted from https://github.com/exx8/differential-diffusion + +import torch +import inspect + +class DifferentialDiffusion(): + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply" + CATEGORY = "_for_testing" + INIT = False + + @classmethod + def IS_CHANGED(s, *args, **kwargs): + DifferentialDiffusion.INIT = s.INIT = True + return "" + + def __init__(self) -> None: + DifferentialDiffusion.INIT = False + self.sigmas: torch.Tensor = None + self.thresholds: torch.Tensor = None + self.mask_i = None + self.valid_sigmas = False + self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"] + + def apply(self, model): + model = model.clone() + model.model_options["denoise_mask_function"] = self.forward + return (model,) + + def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor): + self.__init__() + self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "") + self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler + if self.sigmas is None: + self.sigmas = sigma[:1].repeat(2) + self.sigmas[-1].zero_() + self.sigmas_min = self.sigmas.min() + self.sigmas_max = self.sigmas.max() + self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device) + self.thresholds_min_len = self.thresholds.shape[0] - 1 + if self.valid_sigmas: + thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1) + mask = denoise_mask.unsqueeze(0) + mask = (mask >= thresholds).to(denoise_mask.dtype) + self.mask_i = iter(mask) + + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor): + if self.sigmas is None or DifferentialDiffusion.INIT: + self.init_sigmas(sigma, denoise_mask) + if self.valid_sigmas: + try: + return next(self.mask_i) + except StopIteration: + self.valid_sigmas = False + if self.thresholds_min_len > 1: + nearest_idx = (self.sigmas - sigma[0]).abs().argmin() + if not self.thresholds_min_len > nearest_idx: + nearest_idx = -2 + threshold = self.thresholds[nearest_idx] + else: + threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min) + return (denoise_mask >= threshold).to(denoise_mask.dtype) + +def get_sigmas_and_sampler(frame, target): + found = frame.f_locals[target] + if isinstance(found, torch.Tensor) and found[-1] < 0.1: + return found, frame.f_code.co_name + return False + +def find_outer_instance(target: str, target_type=None, callback=None): + frame = inspect.currentframe() + i = 0 + while frame and i < 100: + if target in frame.f_locals: + if callback is not None: + res = callback(frame, target) + if res: + return res + else: + found = frame.f_locals[target] + if isinstance(found, target_type): + return found + frame = frame.f_back + i += 1 + return None + + +NODE_CLASS_MAPPINGS = { + "DifferentialDiffusion": DifferentialDiffusion, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "DifferentialDiffusion": "Differential Diffusion", +} diff --git a/nodes.py b/nodes.py index 1846797b1..2cfea1a73 100644 --- a/nodes.py +++ b/nodes.py @@ -1961,6 +1961,7 @@ def init_custom_nodes(): "nodes_photomaker.py", "nodes_cond.py", "nodes_stable_cascade.py", + "nodes_differential_diffusion.py", ] for node_file in extras_files: