diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0f7cc4ca..d37d7dd9 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -70,8 +70,14 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): return sigma_down, sigma_up -def default_noise_sampler(x): - return lambda sigma, sigma_next: torch.randn_like(x) +def default_noise_sampler(x, seed=None): + if seed is not None: + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + else: + generator = None + + return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) class BatchedBrownianTree: @@ -168,7 +174,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -189,7 +196,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -290,7 +298,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -318,7 +327,8 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -591,7 +601,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() @@ -625,7 +636,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 lambda_fn = lambda sigma: ((1-sigma)/sigma).log() @@ -1153,7 +1165,8 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler temp = [0] def post_cfg_function(args): @@ -1179,7 +1192,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler temp = [0] def post_cfg_function(args):