mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Implement DDIM sampler.
This commit is contained in:
parent
218f64315d
commit
f04dc2c2f4
@ -22,11 +22,15 @@ class DDIMSampler(object):
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
|
||||
|
||||
def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = torch.tensor(ddim_timesteps)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
@ -52,6 +56,58 @@ class DDIMSampler(object):
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_custom(self,
|
||||
ddim_timesteps,
|
||||
conditioning,
|
||||
callback=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
denoise_function=None,
|
||||
cond_concat=None,
|
||||
to_zero=True,
|
||||
end_step=None,
|
||||
**kwargs
|
||||
):
|
||||
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
|
||||
samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=denoise_function,
|
||||
cond_concat=cond_concat,
|
||||
to_zero=to_zero,
|
||||
end_step=end_step
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
@ -116,7 +172,9 @@ class DDIMSampler(object):
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=None,
|
||||
cond_concat=None
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@ -127,7 +185,7 @@ class DDIMSampler(object):
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None):
|
||||
ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
@ -142,11 +200,11 @@ class DDIMSampler(object):
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
@ -167,7 +225,7 @@ class DDIMSampler(object):
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
@ -176,16 +234,27 @@ class DDIMSampler(object):
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
if to_zero:
|
||||
img = pred_x0
|
||||
else:
|
||||
if ddim_use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
img /= sqrt_alphas_cumprod[index - 1]
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
dynamic_threshold=None, denoise_function=None, cond_concat=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
if denoise_function is not None:
|
||||
model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
|
||||
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
@ -299,7 +368,7 @@ class DDIMSampler(object):
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
@ -311,8 +380,12 @@ class DDIMSampler(object):
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||
if max_denoise:
|
||||
noise_multiplier = 1.0
|
||||
else:
|
||||
noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
|
@ -4,6 +4,8 @@ from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import contextlib
|
||||
import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
@ -234,6 +236,14 @@ def simple_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
sigs = []
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
|
||||
for x in range(len(ddim_timesteps) - 1, -1, -1):
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ddim_timesteps[x])))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
@ -310,10 +320,10 @@ def apply_control_net_to_equal_area(conds, uncond):
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple"]
|
||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
||||
"sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
|
||||
"sample_dpmpp_2m", "uni_pc", "uni_pc_bh2"]
|
||||
"sample_dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
||||
self.model = model
|
||||
@ -350,6 +360,8 @@ class KSampler:
|
||||
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
|
||||
elif self.scheduler == "simple":
|
||||
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
|
||||
elif self.scheduler == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
|
||||
@ -403,6 +415,7 @@ class KSampler:
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
||||
|
||||
cond_concat = None
|
||||
if hasattr(self.model, 'concat_keys'):
|
||||
cond_concat = []
|
||||
for ck in self.model.concat_keys:
|
||||
@ -428,6 +441,32 @@ class KSampler:
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
sampler = DDIMSampler(self.model)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
conditioning=positive,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg,
|
||||
unconditional_conditioning=negative,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
denoise_function=sampling_function,
|
||||
cond_concat=cond_concat,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1)
|
||||
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
|
Loading…
Reference in New Issue
Block a user