From 90f349f93df3083a507854d7fc7c3e1bb9014e24 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 12 Jan 2025 03:09:20 -0500 Subject: [PATCH 1/2] Add res_multistep sampler from the cosmos code. This sampler should work with all models. --- comfy/k_diffusion/res.py | 258 ++++++++++++++++++++++++++++++++++ comfy/k_diffusion/sampling.py | 18 +++ comfy/samplers.py | 2 +- 3 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 comfy/k_diffusion/res.py diff --git a/comfy/k_diffusion/res.py b/comfy/k_diffusion/res.py new file mode 100644 index 00000000..6caedec3 --- /dev/null +++ b/comfy/k_diffusion/res.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied from Nvidia Cosmos code. + +import torch +from torch import Tensor +from typing import Callable, List, Tuple, Optional, Any +import math +from tqdm.auto import trange + + +def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + ndims1 = x.ndim + ndims2 = y.ndim + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_mul(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x * y + + +def phi1(t: torch.Tensor) -> torch.Tensor: + """ + Compute the first order phi function: (exp(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi1 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float32) + return (torch.expm1(t) / t).to(dtype=input_dtype) + + +def phi2(t: torch.Tensor) -> torch.Tensor: + """ + Compute the second order phi function: (phi1(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi2 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float32) + return ((phi1(t) - 1.0) / t).to(dtype=input_dtype) + + +def res_x0_rk2_step( + x_s: torch.Tensor, + t: torch.Tensor, + s: torch.Tensor, + x0_s: torch.Tensor, + s1: torch.Tensor, + x0_s1: torch.Tensor, +) -> torch.Tensor: + """ + Perform a residual-based 2nd order Runge-Kutta step. + + Args: + x_s: Current state tensor. + t: Target time tensor. + s: Current time tensor. + x0_s: Prediction at current time. + s1: Intermediate time tensor. + x0_s1: Prediction at intermediate time. + + Returns: + Tensor: Updated state tensor. + + Raises: + AssertionError: If step size is too small. 
+ """ + s = -torch.log(s) + t = -torch.log(t) + m = -torch.log(s1) + + dt = t - s + assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + + c2 = (m - s) / dt + phi1_val, phi2_val = phi1(-dt), phi2(-dt) + + # Handle edge case where t = s = m + b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0) + b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0) + + return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1)) + + +def reg_x0_euler_step( + x_s: torch.Tensor, + s: torch.Tensor, + t: torch.Tensor, + x0_s: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a regularized Euler step based on x0 prediction. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_s: Prediction at current time. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current prediction. + """ + coef_x0 = (s - t) / s + coef_xs = t / s + return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s + + +def order2_fn( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + impl the second order multistep method in https://arxiv.org/pdf/2308.02157 + Adams Bashforth approach! + """ + if x0_preds: + x0_s1, s1 = x0_preds[0] + x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) + else: + x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] + return x_t, [(x0_s, s)] + + +class SolverConfig: + is_multi: bool = True + rk: str = "2mid" + multistep: str = "2ab" + s_churn: float = 0.0 + s_t_max: float = float("inf") + s_t_min: float = 0.0 + s_noise: float = 1.0 + + +def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any, disable=None) -> Any: + """ + Implements a for loop with a function. + + Args: + lower: Lower bound of the loop (inclusive). + upper: Upper bound of the loop (exclusive). + body_fun: Function to be applied in each iteration. + init_val: Initial value for the loop. + + Returns: + The final result after all iterations. + """ + val = init_val + for i in trange(lower, upper, disable=disable): + val = body_fun(i, val) + return val + + +def differential_equation_solver( + x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + sigmas_L: torch.Tensor, + solver_cfg: SolverConfig, + noise_sampler, + callback=None, + disable=None, +) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Creates a differential equation solver function. + + Args: + x0_fn: Function to compute x0 prediction. + sigmas_L: Tensor of sigma values with shape [L,]. + solver_cfg: Configuration for the solver. + + Returns: + A function that solves the differential equation. + """ + num_step = len(sigmas_L) - 1 + + # if solver_cfg.is_multi: + # update_step_fn = get_multi_step_fn(solver_cfg.multistep) + # else: + # update_step_fn = get_runge_kutta_fn(solver_cfg.rk) + update_step_fn = order2_fn + + eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1) + + def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor: + """ + Samples from the differential equation. + + Args: + input_xT_B_StateShape: Input tensor with shape [B, StateShape]. + + Returns: + Output tensor with shape [B, StateShape]. 
+ """ + ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float32) + + def step_fn( + i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]] + ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + input_x_B_StateShape, x0_preds = state + sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1] + + if sigma_next_0 == 0: + output_x_B_StateShape = x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) + else: + # algorithm 2: line 4-6 + if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max and eta > 0: + hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0 + input_x_B_StateShape = input_x_B_StateShape + ( + hat_sigma_cur_0**2 - sigma_cur_0**2 + ).sqrt() * solver_cfg.s_noise * noise_sampler(sigma_cur_0, sigma_next_0) # torch.randn_like(input_x_B_StateShape) + sigma_cur_0 = hat_sigma_cur_0 + + if solver_cfg.is_multi: + x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds + ) + else: + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn + ) + + if callback is not None: + callback({'x': input_x_B_StateShape, 'i': i_th, 'sigma': sigma_cur_0, 'sigma_hat': sigma_cur_0, 'denoised': x0_pred_B_StateShape}) + + return output_x_B_StateShape, x0_preds + + x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None], disable=disable) + return x_at_eps + + return sample_fn diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 413ad5aa..3a98e6a7 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -8,6 +8,7 @@ from tqdm.auto import trange, tqdm from . import utils from . import deis +from . 
import res import comfy.model_patcher import comfy.model_sampling @@ -1265,3 +1266,20 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis x = denoised + denoised_mix + torch.exp(-h) * x old_uncond_denoised = uncond_denoised return x + +@torch.no_grad() +def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None): + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + + x0_func = lambda x, sigma: model(x, sigma, **extra_args) + + solver_cfg = res.SolverConfig() + solver_cfg.s_churn = s_churn + solver_cfg.s_t_max = s_tmax + solver_cfg.s_t_min = s_tmin + solver_cfg.s_noise = s_noise + + x = res.differential_equation_solver(x0_func, sigmas, solver_cfg, noise_sampler, callback=callback, disable=disable)(x) + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index 5cc33a7d..fa176c6d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -687,7 +687,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis"] + "ipndm", "ipndm_v", "deis", "res_multistep"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): From 1f1c7b7b5673fac3d3d38a1291ed1171f6cdc3eb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 13 Jan 2025 03:52:37 -0500 Subject: [PATCH 2/2] Remove useless code. --- comfy/ldm/cosmos/cosmos_tokenizer/utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py index 64dd5e28..3af8d0d0 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py @@ -17,7 +17,7 @@ from typing import Any import torch -from einops import pack, rearrange, unpack +from einops import rearrange import comfy.ops @@ -98,14 +98,6 @@ def default(*args): return None -def pack_one(t, pattern): - return pack([t], pattern) - - -def unpack_one(t, ps, pattern): - return unpack(t, ps, pattern)[0] - - def round_ste(z: torch.Tensor) -> torch.Tensor: """Round with straight through gradients.""" zhat = z.round()
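
Notes on the solver update: as I read res_x0_rk2_step, the step it performs is an exponential-integrator update in log-sigma time. Writing lambda = -log(sigma), h = lambda_t - lambda_s and c2 = (lambda_prev - lambda_s) / h, the update is

    x_t = exp(-h) * x_s
          + h * [ (phi1(-h) - phi2(-h)/c2) * D(x_s, sigma_s)
                + (phi2(-h)/c2)            * D(x_prev, sigma_prev) ]

with phi1(u) = (exp(u) - 1)/u and phi2(u) = (phi1(u) - 1)/u, where D is the model's denoised (x0) prediction and (x_prev, sigma_prev) is the prediction cached from the previous step in x0_preds (the Adams-Bashforth-style history). On the first step, where no history exists yet, order2_fn falls back to the plain x0 Euler step in reg_x0_euler_step.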
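
For quick testing outside the full KSampler pipeline, the sampler added here can also be driven directly through the new k-diffusion entry point. A minimal sketch, assuming a ComfyUI checkout with this patch applied and the vendored get_sigmas_karras schedule; toy_denoiser is a stand-in for the CFG-wrapped model ComfyUI normally passes to k-diffusion samplers, and the latent shape and sigma range are placeholder values:

    import torch
    from comfy.k_diffusion import sampling as k_sampling

    def toy_denoiser(x, sigma):
        # stand-in for the real model: sigma arrives as a [B] tensor and the
        # callable must return an x0 (denoised) prediction with x's shape
        sigma = sigma.view(-1, *([1] * (x.ndim - 1)))
        return x / (1.0 + sigma ** 2)

    sigmas = k_sampling.get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6)
    x = torch.randn(1, 4, 64, 64) * sigmas[0]   # start from noise scaled to sigma_max
    out = k_sampling.sample_res_multistep(toy_denoiser, x, sigmas)

With the default s_churn=0.0 the solver is deterministic and the noise_sampler is never invoked; passing s_churn > 0 enables the stochastic churn branch of algorithm 2.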