From 0542088ef895b4825df80fd3babf91513441af65 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 Apr 2024 00:48:42 -0400 Subject: [PATCH] Refactor sampler code for more advanced sampler nodes part 2. --- comfy/sample.py | 102 +++--------------------------- comfy/sampler_helpers.py | 76 +++++++++++++++++++++++ comfy/samplers.py | 130 ++++++++++++++++++++++++--------------- 3 files changed, 165 insertions(+), 143 deletions(-) create mode 100644 comfy/sampler_helpers.py diff --git a/comfy/sample.py b/comfy/sample.py index 3c65d0a8..e51bd67d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,10 +1,9 @@ import torch import comfy.model_management import comfy.samplers -import comfy.conds import comfy.utils -import math import numpy as np +import logging def prepare_noise(latent_image, seed, noise_inds=None): """ @@ -25,106 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises = torch.cat(noises, axis=0) return noises -def prepare_mask(noise_mask, shape, device): - """ensures noise mask is of proper dimensions""" - noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = torch.cat([noise_mask] * shape[1], dim=1) - noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) - noise_mask = noise_mask.to(device) - return noise_mask - -def get_models_from_cond(cond, model_type): - models = [] - for c in cond: - if model_type in c: - models += [c[model_type]] - return models - -def convert_cond(cond): - out = [] - for c in cond: - temp = c[1].copy() - model_conds = temp.get("model_conds", {}) - if c[0] is not None: - model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove - temp["cross_attn"] = c[0] - temp["model_conds"] = model_conds - out.append(temp) - return out - -def get_additional_models(conds, dtype): - """loads additional models in conditioning""" - cnets = [] - gligen = [] - - for i in range(len(conds)): - cnets += get_models_from_cond(conds[i], "control") - gligen += get_models_from_cond(conds[i], "gligen") - - control_nets = set(cnets) - - inference_memory = 0 - control_models = [] - for m in control_nets: - control_models += m.get_models() - inference_memory += m.inference_memory_requirements(dtype) - - gligen = [x[1] for x in gligen] - models = control_models + gligen - return models, inference_memory +def prepare_sampling(model, noise_shape, positive, negative, noise_mask): + logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed") + return model, positive, negative, noise_mask, [] def cleanup_additional_models(models): - """cleanup additional models that were loaded""" - for m in models: - if hasattr(m, 'cleanup'): - m.cleanup() - -def prepare_sampling(model, noise_shape, conds, noise_mask): - device = model.load_device - for i in range(len(conds)): - conds[i] = convert_cond(conds[i]) - - if noise_mask is not None: - noise_mask = prepare_mask(noise_mask, noise_shape, device) - - real_model = None - models, inference_memory = get_additional_models(conds, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) - real_model = model.model - - return real_model, conds, noise_mask, models - + logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed") def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): - real_model, conds_copy, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask) - positive_copy, negative_copy = conds_copy + sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) - - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(comfy.model_management.intermediate_device()) - - cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): - real_model, conds, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask) - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) - sigmas = sigmas.to(model.load_device) - - samples = comfy.samplers.sample(real_model, noise, conds[0], conds[1], cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(comfy.model_management.intermediate_device()) - cleanup_additional_models(models) - - control_cleanup = [] - for i in range(len(conds)): - control_cleanup += get_models_from_cond(conds[i], "control") - - cleanup_additional_models(set(control_cleanup)) return samples - diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py new file mode 100644 index 00000000..a18abd9e --- /dev/null +++ b/comfy/sampler_helpers.py @@ -0,0 +1,76 @@ +import torch +import comfy.model_management +import comfy.conds + +def prepare_mask(noise_mask, shape, device): + """ensures noise mask is of proper dimensions""" + noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") + noise_mask = torch.cat([noise_mask] * shape[1], dim=1) + noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) + noise_mask = noise_mask.to(device) + return noise_mask + +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c: + models += [c[model_type]] + return models + +def convert_cond(cond): + out = [] + for c in cond: + temp = c[1].copy() + model_conds = temp.get("model_conds", {}) + if c[0] is not None: + model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove + temp["cross_attn"] = c[0] + temp["model_conds"] = model_conds + out.append(temp) + return out + +def get_additional_models(conds, dtype): + """loads additional models in conditioning""" + cnets = [] + gligen = [] + + for k in conds: + cnets += get_models_from_cond(conds[k], "control") + gligen += get_models_from_cond(conds[k], "gligen") + + control_nets = set(cnets) + + inference_memory = 0 + control_models = [] + for m in control_nets: + control_models += m.get_models() + inference_memory += m.inference_memory_requirements(dtype) + + gligen = [x[1] for x in gligen] + models = control_models + gligen + return models, inference_memory + +def cleanup_additional_models(models): + """cleanup additional models that were loaded""" + for m in models: + if hasattr(m, 'cleanup'): + m.cleanup() + + +def prepare_sampling(model, noise_shape, conds): + device = model.load_device + real_model = None + models, inference_memory = get_additional_models(conds, model.model_dtype()) + comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) + real_model = model.model + + return real_model, conds, models + +def cleanup_models(conds, models): + cleanup_additional_models(models) + + control_cleanup = [] + for k in conds: + control_cleanup += get_models_from_cond(conds[k], "control") + + cleanup_additional_models(set(control_cleanup)) diff --git a/comfy/samplers.py b/comfy/samplers.py index f18de200..57f5632c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -5,6 +5,7 @@ import collections from comfy import model_management import math import logging +import comfy.sampler_helpers def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) @@ -230,58 +231,45 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) +def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}): + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result + #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: - uncond_ = None - else: - uncond_ = uncond + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: + uncond_ = None + else: + uncond_ = uncond + + conds = [cond, uncond_] + out = calc_cond_batch(model, conds, x, timestep, model_options) + return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options) - conds = [cond, uncond_] - - out = calc_cond_batch(model, conds, x, timestep, model_options) - cond_pred = out[0] - uncond_pred = out[1] - - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} - cfg_result = x - model_options["sampler_cfg_function"](args) - else: - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale - - for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, - "sigma": timestep, "model_options": model_options, "input": x} - cfg_result = fn(args) - - return cfg_result - -class CFGNoisePredictor(torch.nn.Module): - def __init__(self, model, cond_scale=1.0): - super().__init__() - self.inner_model = model - self.cond_scale = cond_scale - def apply_model(self, x, timestep, conds, model_options={}, seed=None): - out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed) - return out - def forward(self, *args, **kwargs): - return self.apply_model(*args, **kwargs) - -class KSamplerX0Inpaint(torch.nn.Module): +class KSamplerX0Inpaint: def __init__(self, model, sigmas): - super().__init__() self.inner_model = model self.sigmas = sigmas - def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None): + def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask - out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed) + out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: out = out * denoise_mask + self.latent_image * latent_mask return out @@ -601,22 +589,66 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N return conds +class CFGGuider: + def __init__(self, model_patcher): + self.model_patcher = model_patcher + self.model_options = model_patcher.model_options + self.original_conds = {} + self.cfg = 1.0 -def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. - latent_image = model.process_latent_in(latent_image) + def set_conds(self, conds): + for k in conds: + self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) - conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed) - model_wrap = guider_class(model) + def set_cfg(self, cfg): + self.cfg = cfg - extra_args = {"conds": conds, "model_options": model_options, "seed":seed} + def __call__(self, *args, **kwargs): + return self.predict_noise(*args, **kwargs) - samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) - return model.process_latent_out(samples.to(torch.float32)) + def predict_noise(self, x, timestep, model_options={}, seed=None): + return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) + + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = self.inner_model.process_latent_in(latent_image) + + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + + extra_args = {"model_options": self.model_options, "seed":seed} + + samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + return self.inner_model.process_latent_out(samples.to(torch.float32)) + + def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + self.conds = {} + for k in self.original_conds: + self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) + + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + device = self.model_patcher.load_device + + if denoise_mask is not None: + denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) + + noise = noise.to(device) + latent_image = latent_image.to(device) + sigmas = sigmas.to(device) + + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + + comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) + del self.inner_model + del self.conds + del self.loaded_models + return output def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed) + cfg_guider = CFGGuider(model) + cfg_guider.set_conds({"positive": positive, "negative": negative}) + cfg_guider.set_cfg(cfg) + return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] @@ -676,7 +708,7 @@ class KSampler: steps += 1 discard_penultimate_sigma = True - sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps) + sigmas = calculate_sigmas_scheduler(self.model.model, self.scheduler, steps) if discard_penultimate_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])