mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-13 22:31:08 +00:00
Refactor sampling code for more advanced sampler nodes.
This commit is contained in:
parent
6c6a39251f
commit
57753c964a
@ -52,9 +52,16 @@ def convert_cond(cond):
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets = []
|
||||
gligen = []
|
||||
|
||||
for i in range(len(conds)):
|
||||
cnets += get_models_from_cond(conds[i], "control")
|
||||
gligen += get_models_from_cond(conds[i], "gligen")
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
@ -62,7 +69,6 @@ def get_additional_models(positive, negative, dtype):
|
||||
control_models += m.get_models()
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
return models, inference_memory
|
||||
@ -73,24 +79,25 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
def prepare_sampling(model, noise_shape, conds, noise_mask):
|
||||
device = model.load_device
|
||||
positive = convert_cond(positive)
|
||||
negative = convert_cond(negative)
|
||||
for i in range(len(conds)):
|
||||
conds[i] = convert_cond(conds[i])
|
||||
|
||||
if noise_mask is not None:
|
||||
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
||||
|
||||
real_model = None
|
||||
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, positive, negative, noise_mask, models
|
||||
return real_model, conds, noise_mask, models
|
||||
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
||||
real_model, conds_copy, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
|
||||
positive_copy, negative_copy = conds_copy
|
||||
|
||||
noise = noise.to(model.load_device)
|
||||
latent_image = latent_image.to(model.load_device)
|
||||
@ -105,14 +112,19 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
||||
real_model, conds, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
|
||||
noise = noise.to(model.load_device)
|
||||
latent_image = latent_image.to(model.load_device)
|
||||
sigmas = sigmas.to(model.load_device)
|
||||
|
||||
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = comfy.samplers.sample(real_model, noise, conds[0], conds[1], cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
cleanup_additional_models(models)
|
||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
||||
|
||||
control_cleanup = []
|
||||
for i in range(len(conds)):
|
||||
control_cleanup += get_models_from_cond(conds[i], "control")
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
return samples
|
||||
|
||||
|
@ -260,11 +260,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
return cfg_result
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, cond_scale=1.0):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
self.cond_scale = cond_scale
|
||||
def apply_model(self, x, timestep, conds, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
@ -274,13 +275,13 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||
def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None):
|
||||
if denoise_mask is not None:
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||
out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
return out
|
||||
@ -568,45 +569,56 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
return model_denoise
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
negative = negative[:]
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
|
||||
|
||||
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
|
||||
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
|
||||
for k in conds:
|
||||
calculate_start_end_timesteps(model, conds[k])
|
||||
|
||||
model_wrap = wrap_model(model)
|
||||
if hasattr(model, 'extra_conds'):
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for k in conds:
|
||||
for c in conds[k]:
|
||||
for kk in conds:
|
||||
if k != kk:
|
||||
create_cond_with_same_area_if_none(conds[kk], c)
|
||||
|
||||
for k in conds:
|
||||
pre_run_control(model, conds[k])
|
||||
|
||||
if "positive" in conds:
|
||||
positive = conds["positive"]
|
||||
for k in conds:
|
||||
if k != "positive":
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
return conds
|
||||
|
||||
|
||||
def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed)
|
||||
model_wrap = guider_class(model)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
create_cond_with_same_area_if_none(negative, c)
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
pre_run_control(model, negative + positive)
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
extra_args = {"conds": conds, "model_options": model_options, "seed":seed}
|
||||
|
||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user