diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 5e2d7364..e00ffd3f 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -78,7 +78,7 @@ class DDIMSampler(object): dynamic_threshold=None, ucg_schedule=None, denoise_function=None, - cond_concat=None, + extra_args=None, to_zero=True, end_step=None, **kwargs @@ -101,7 +101,7 @@ class DDIMSampler(object): dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=denoise_function, - cond_concat=cond_concat, + extra_args=extra_args, to_zero=to_zero, end_step=end_step ) @@ -174,7 +174,7 @@ class DDIMSampler(object): dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=None, - cond_concat=None + extra_args=None ) return samples, intermediates @@ -185,7 +185,7 @@ class DDIMSampler(object): mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None): + ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -225,7 +225,7 @@ class DDIMSampler(object): corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat) + dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) @@ -249,11 +249,11 @@ class DDIMSampler(object): def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None, denoise_function=None, cond_concat=None): + dynamic_threshold=None, denoise_function=None, extra_args=None): b, *_, device = *x.shape, x.device if denoise_function is not None: - model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat) + model_output = denoise_function(self.model.apply_model, x, t, **extra_args) elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: model_output = self.model.apply_model(x, t, c) else: diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 42ed2add..6af96124 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module): self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}): if self.conditioning_key is None: - out = self.diffusion_model(x, t, control=control) + out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t, control=control) + out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'crossattn': if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) @@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module): # TorchScript changes names of the arguments # with argument cc defined as context=cc scripted model will produce # an error: RuntimeError: forward() is missing value for argument 'argument_3'. - out = self.scripted_diffusion_model(x, t, cc, control=control) + out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options) else: - out = self.diffusion_model(x, t, context=cc, control=control) + out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, control=control) + out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'hybrid-adm': assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'crossattn-adm': assert c_adm is not None cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) + out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'adm': cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc, control=control) + out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options) else: raise NotImplementedError() diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 23b04734..07553627 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention import model_management +from . import tomesd if model_management.xformers_enabled(): import xformers @@ -504,12 +505,22 @@ class BasicTransformerBlock(nn.Module): self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint - def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + def forward(self, x, context=None, transformer_options={}): + return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) - def _forward(self, x, context=None): - x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x - x = self.attn2(self.norm2(x), context=context) + x + def _forward(self, x, context=None, transformer_options={}): + n = self.norm1(x) + if "tomesd" in transformer_options: + m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) + n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + else: + n = self.attn1(n, context=context if self.disable_self_attn else None) + + x += n + n = self.norm2(x) + n = self.attn2(n, context=context) + + x += n x = self.ff(self.norm3(x)) + x return x @@ -557,7 +568,7 @@ class SpatialTransformer(nn.Module): self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear - def forward(self, x, context=None): + def forward(self, x, context=None, transformer_options={}): # note: if no context is given, cross-attention defaults to self-attention if not isinstance(context, list): context = [context] @@ -570,7 +581,7 @@ class SpatialTransformer(nn.Module): if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) + x = block(x, context=context[i], transformer_options=transformer_options) if self.use_linear: x = self.proj_out(x) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 09ab1a06..7b2f5b53 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): support it as an extra input. """ - def forward(self, x, emb, context=None): + def forward(self, x, emb, context=None, transformer_options={}): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer): - x = layer(x, context) + x = layer(x, context, transformer_options) else: x = layer(x) return x @@ -753,7 +753,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -762,6 +762,7 @@ class UNetModel(nn.Module): :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ + transformer_options["original_shape"] = list(x.shape) assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -775,13 +776,13 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): - h = module(h, emb, context) + h = module(h, emb, context, transformer_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) - h = self.middle_block(h, emb, context) + h = self.middle_block(h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() @@ -793,7 +794,7 @@ class UNetModel(nn.Module): hsp += ctrl h = th.cat([h, hsp], dim=1) del hsp - h = module(h, emb, context) + h = module(h, emb, context, transformer_options) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py new file mode 100644 index 00000000..1eafcd0a --- /dev/null +++ b/comfy/ldm/modules/tomesd.py @@ -0,0 +1,117 @@ + + +import torch +from typing import Tuple, Callable +import math + +def do_nothing(x: torch.Tensor, mode:str=None): + return x + + +def bipartite_soft_matching_random2d(metric: torch.Tensor, + w: int, h: int, sx: int, sy: int, r: int, + no_rand: bool = False) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. + + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + """ + B, N, _ = metric.shape + + if r <= 0: + return do_nothing, do_nothing + + with torch.no_grad(): + + hsy, wsx = h // sy, w // sx + + # For each sy by sx kernel, randomly assign one token to be dst and the rest src + idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device) + + if no_rand: + rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64) + else: + rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device) + + idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype)) + idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1) + rand_idx = idx_buffer.argsort(dim=1) + + num_dst = int((1 / (sx*sy)) * N) + a_idx = rand_idx[:, num_dst:, :] # src + b_idx = rand_idx[:, :num_dst, :] # dst + + def split(x): + C = x.shape[-1] + src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C)) + return src, dst + + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], r) + + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + + unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + _, _, c = unm.shape + + src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c)) + + # Combine back to the original shape + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src) + + return out + + return merge, unmerge + + +def get_functions(x, ratio, original_shape): + b, c, original_h, original_w = original_shape + original_tokens = original_h * original_w + downsample = int(math.sqrt(original_tokens // x.shape[1])) + stride_x = 2 + stride_y = 2 + max_downsample = 1 + + if downsample <= max_downsample: + w = original_w // downsample + h = original_h // downsample + r = int(x.shape[1] * ratio) + no_rand = False + m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) + return m, u + + nothing = lambda y: y + return nothing, nothing diff --git a/comfy/samplers.py b/comfy/samplers.py index 66218f88..15e78bbd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module): #The main sampling function shared by all the samplers #Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 @@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con out['c_concat'] = [torch.cat(c_concat)] return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in): + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if control is not None: c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) + if 'transformer_options' in model_options: + c['transformer_options'] = model_options['transformer_options'] + output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x @@ -192,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) return uncond + (cond - uncond) * cond_scale @@ -209,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module): super().__init__() self.inner_model = model self.alphas_cumprod = model.alphas_cumprod - def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat) + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options) return out @@ -218,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}): if denoise_mask is not None: latent_mask = 1. - denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options) if denoise_mask is not None: out *= denoise_mask @@ -330,7 +333,7 @@ class KSampler: "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] - def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): + def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model self.model_denoise = CFGNoisePredictor(self.model) if self.model.parameterization == "v": @@ -350,6 +353,7 @@ class KSampler: self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) self.denoise = denoise + self.model_options = model_options def _calculate_sigmas(self, steps): sigmas = None @@ -418,7 +422,7 @@ class KSampler: else: precision_scope = contextlib.nullcontext - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} cond_concat = None if hasattr(self.model, 'concat_keys'): @@ -467,7 +471,7 @@ class KSampler: x_T=z_enc, x0=latent_image, denoise_function=sampling_function, - cond_concat=cond_concat, + extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, end_step=sigmas.shape[0] - 1) diff --git a/comfy/sd.py b/comfy/sd.py index 2e1ae840..2a38ceb1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,5 +1,6 @@ import torch import contextlib +import copy import sd1_clip import sd2_clip @@ -274,12 +275,20 @@ class ModelPatcher: self.model = model self.patches = [] self.backup = {} + self.model_options = {"transformer_options":{}} def clone(self): n = ModelPatcher(self.model) n.patches = self.patches[:] + n.model_options = copy.deepcopy(self.model_options) return n + def set_model_tomesd(self, ratio): + self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} + + def model_dtype(self): + return self.model.diffusion_model.dtype + def add_patches(self, patches, strength=1.0): p = {} model_sd = self.model.state_dict() diff --git a/nodes.py b/nodes.py index b422f2cb..a10cf5a0 100644 --- a/nodes.py +++ b/nodes.py @@ -254,6 +254,22 @@ class LoraLoader: model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) +class TomePatchModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, ratio): + m = model.clone() + m.set_model_tomesd(ratio) + return (m, ) + class VAELoader: @classmethod def INPUT_TYPES(s): @@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, model_management.load_controlnet_gpu(control_net_models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) else: #other samplers pass @@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = { "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, "VAEEncodeTiled": VAEEncodeTiled, + "TomePatchModel": TomePatchModel, } def load_custom_node(module_path): diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index ff9227d2..7e668826 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -101,7 +101,7 @@ app.registerExtension({ callback: () => convertToWidget(this, w), }); } else { - const config = nodeData?.input?.required[w.name] || [w.type, w.options || {}]; + const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; if (isConvertableWidget(w, config)) { toInput.push({ content: `Convert ${w.name} to input`,