From 3b19fc76e34692d779ceffe233e0a952cbcd20ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 18 Mar 2025 05:09:25 -0400 Subject: [PATCH 01/37] Allow disabling pe in flux code for some other models. --- comfy/ldm/flux/math.py | 9 +++++---- comfy/ldm/flux/model.py | 7 +++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index c0cbd291..3e097817 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: q_shape = q.shape k_shape = k.shape - q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) - k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) + if pe is not None: + q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) + k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) + q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) + k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index cc34f758..ef4ba410 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -115,8 +115,11 @@ class Flux(nn.Module): vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) txt = self.txt_in(txt) - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) + if img_ids is not None: + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + else: + pe = None blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.double_blocks): From 11f1b41bab62ece770aa1d3aacc59a450e277b41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Mar 2025 16:19:50 -0400 Subject: [PATCH 02/37] Initial Hunyuan3Dv2 implementation. Supports the multiview, mini, turbo models and VAEs. --- comfy/latent_formats.py | 10 + comfy/ldm/hunyuan3d/model.py | 135 ++++++++ comfy/ldm/hunyuan3d/vae.py | 587 ++++++++++++++++++++++++++++++++ comfy/model_base.py | 16 + comfy/model_detection.py | 17 +- comfy/sd.py | 15 +- comfy/supported_models.py | 38 ++- comfy_extras/nodes_hunyuan3d.py | 410 ++++++++++++++++++++++ nodes.py | 1 + 9 files changed, 1225 insertions(+), 4 deletions(-) create mode 100644 comfy/ldm/hunyuan3d/model.py create mode 100644 comfy/ldm/hunyuan3d/vae.py create mode 100644 comfy_extras/nodes_hunyuan3d.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 622c1df5..556c3951 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -456,3 +456,13 @@ class Wan21(LatentFormat): latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean + +class Hunyuan3Dv2(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 0.9990943042622529 + +class Hunyuan3Dv2mini(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 1.0188137142395404 diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py new file mode 100644 index 00000000..4e18358f --- /dev/null +++ b/comfy/ldm/hunyuan3d/model.py @@ -0,0 +1,135 @@ +import torch +from torch import nn +from comfy.ldm.flux.layers import ( + DoubleStreamBlock, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +class Hunyuan3Dv2(nn.Module): + def __init__( + self, + in_channels=64, + context_in_dim=1536, + hidden_size=1024, + mlp_ratio=4.0, + num_heads=16, + depth=16, + depth_single_blocks=32, + qkv_bias=True, + guidance_embed=False, + image_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + + if hidden_size % num_heads != 0: + raise ValueError( + f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}" + ) + + self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead + self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None + ) + self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device) + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth) + ] + ) + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth_single_blocks) + ] + ) + self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): + x = x.movedim(-1, -2) + timestep = 1.0 - timestep + txt = context + img = self.latent_in(x) + + vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype)) + if self.guidance_in is not None: + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype)) + + txt = self.cond_in(txt) + pe = None + attn_mask = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + + img = img[:, txt.shape[1]:, ...] + img = self.final_layer(img, vec) + return img.movedim(-2, -1) * (-1.0) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py new file mode 100644 index 00000000..311c9b41 --- /dev/null +++ b/comfy/ldm/hunyuan3d/vae.py @@ -0,0 +1,587 @@ +# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py +# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from typing import Union, Tuple, List, Callable, Optional + +import numpy as np +from einops import repeat, rearrange +from tqdm import tqdm +import logging + +import comfy.ops +ops = comfy.ops.disable_weight_init + +def generate_dense_grid_points( + bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_resolution: int, + indexing: str = "ij", +): + length = bbox_max - bbox_min + num_cells = octree_resolution + + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + + +class VanillaVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + octree_resolution: int = None, + enable_pbar: bool = True, + **kwargs, + ): + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=octree_resolution, + indexing="ij" + ) + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + + # 2. latents to 3d volume + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + disable=not enable_pbar): + chunk_queries = xyz_samples[start: start + num_chunks, :] + chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=chunk_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1) + grid_logits = grid_logits.view((batch_size, *grid_size)).float() + + return grid_logits + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class CrossAttentionProcessor: + def __call__(self, attn, q, k, v): + out = F.scaled_dot_product_attention(q, k, v) + return out + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if self.drop_prob == 0. or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob, 3):0.3f}' + + +class MLP(nn.Module): + def __init__( + self, *, + width: int, + expand_ratio: int = 4, + output_width: int = None, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.c_fc = ops.Linear(width, width * expand_ratio) + self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width) + self.gelu = nn.GELU() + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.attn_processor = CrossAttentionProcessor() + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = self.attn_processor(self, q, k, v) + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + data_width: Optional[int] = None, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + kv_cache: bool = False, + ): + super().__init__() + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = ops.Linear(width, width, bias=qkv_bias) + self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadCrossAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.kv_cache = kv_cache + self.data = None + + def forward(self, x, data): + x = self.c_q(x) + if self.kv_cache: + if self.data is None: + self.data = self.c_kv(data) + logging.info('Save kv cache,this should be called only once for one mesh') + data = self.data + else: + data = self.c_kv(data) + x = self.attention(x, data) + x = self.c_proj(x) + return x + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + data_width: Optional[int] = None, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6) + self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.drop_path(self.c_proj(x)) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.attn = MultiheadAttention( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, drop_path_rate=drop_path_rate) + self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class CrossAttentionDecoder(nn.Module): + + def __init__( + self, + *, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + downsample_ratio: int = 1, + enable_ln_post: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary" + ): + super().__init__() + + self.enable_ln_post = enable_ln_post + self.fourier_embedder = fourier_embedder + self.downsample_ratio = downsample_ratio + self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) + if self.downsample_ratio != 1: + self.latents_proj = ops.Linear(width * downsample_ratio, width) + if self.enable_ln_post == False: + qk_norm = False + self.cross_attn_decoder = ResidualCrossAttentionBlock( + width=width, + mlp_expand_ratio=mlp_expand_ratio, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm + ) + + if self.enable_ln_post: + self.ln_post = ops.LayerNorm(width) + self.output_proj = ops.Linear(width, out_channels) + self.label_type = label_type + self.count = 0 + + def forward(self, queries=None, query_embeddings=None, latents=None): + if query_embeddings is None: + query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype)) + self.count += query_embeddings.shape[1] + if self.downsample_ratio != 1: + latents = self.latents_proj(latents) + x = self.cross_attn_decoder(query_embeddings, latents) + if self.enable_ln_post: + x = self.ln_post(x) + occ = self.output_proj(x) + return occ + + +class ShapeVAE(nn.Module): + def __init__( + self, + *, + embed_dim: int, + width: int, + heads: int, + num_decoder_layers: int, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + include_pi: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + drop_path_rate: float = 0.0, + scale_factor: float = 1.0, + ): + super().__init__() + self.geo_decoder_ln_post = geo_decoder_ln_post + + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + self.post_kl = ops.Linear(embed_dim, width) + + self.transformer = Transformer( + width=width, + layers=num_decoder_layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + + self.geo_decoder = CrossAttentionDecoder( + fourier_embedder=self.fourier_embedder, + out_channels=1, + mlp_expand_ratio=geo_decoder_mlp_expand_ratio, + downsample_ratio=geo_decoder_downsample_ratio, + enable_ln_post=self.geo_decoder_ln_post, + width=width // geo_decoder_downsample_ratio, + heads=heads // geo_decoder_downsample_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + label_type=label_type, + ) + + self.volume_decoder = VanillaVolumeDecoder() + self.scale_factor = scale_factor + + def decode(self, latents, **kwargs): + latents = self.post_kl(latents.movedim(-2, -1)) + latents = self.transformer(latents) + + bounds = kwargs.get("bounds", 1.01) + num_chunks = kwargs.get("num_chunks", 8000) + octree_resolution = kwargs.get("octree_resolution", 256) + enable_pbar = kwargs.get("enable_pbar", True) + + grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) + return grid_logits + + def encode(self, x): + return None diff --git a/comfy/model_base.py b/comfy/model_base.py index 976702b6..f02406ac 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -36,6 +36,7 @@ import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model import comfy.ldm.lumina.model import comfy.ldm.wan.model +import comfy.ldm.hunyuan3d.model import comfy.model_management import comfy.patcher_extension @@ -1013,3 +1014,18 @@ class WAN21(BaseModel): if clip_vision_output is not None: out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) return out + +class Hunyuan3Dv2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 403da585..f9e96ab7 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -323,6 +323,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "t2v" return dit_config + if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D + in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape + dit_config = {} + dit_config["image_model"] = "hunyuan3d2" + dit_config["in_channels"] = in_shape[1] + dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') + dit_config["qkv_bias"] = True + dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index 3d72a04d..4160fa89 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.hunyuan3d.vae import yaml import math @@ -412,6 +413,16 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 + ln_post = "geo_decoder.ln_post.weight" in sd + inner_size = sd["geo_decoder.output_proj.weight"].shape[1] + downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size + mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size + self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) + ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} + self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -498,7 +509,7 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - def decode(self, samples_in): + def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None try: @@ -510,7 +521,7 @@ class VAE: for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) + out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b4d7bfe2..b5c3194c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -959,6 +959,42 @@ class WAN21_I2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +class Hunyuan3Dv2(supported_models_base.BASE): + unet_config = { + "image_model": "hunyuan3d2", + } + + unet_extra_config = {} + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.0, + } + + clip_vision_prefix = "conditioner.main_image_encoder.model." + vae_key_prefix = ["vae."] + + latent_format = latent_formats.Hunyuan3Dv2 + + def process_unet_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None + +class Hunyuan3Dv2mini(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2", + "depth": 8, + } + + latent_format = latent_formats.Hunyuan3Dv2mini + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py new file mode 100644 index 00000000..6abcde1f --- /dev/null +++ b/comfy_extras/nodes_hunyuan3d.py @@ -0,0 +1,410 @@ +import torch +import os +import json +import struct +import numpy as np +from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch +import folder_paths +import comfy.model_management +from comfy.cli_args import args + + +class EmptyLatentHunyuan3Dv2: + @classmethod + def INPUT_TYPES(s): + return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent/3d" + + def generate(self, resolution, batch_size): + latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) + return ({"samples": latent, "type": "hunyuan3dv2"}, ) + + +class Hunyuan3Dv2Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision_output): + embeds = clip_vision_output.last_hidden_state + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class Hunyuan3Dv2ConditioningMultiView: + @classmethod + def INPUT_TYPES(s): + return {"required": {}, + "optional": {"front": ("CLIP_VISION_OUTPUT",), + "left": ("CLIP_VISION_OUTPUT",), + "back": ("CLIP_VISION_OUTPUT",), + "right": ("CLIP_VISION_OUTPUT",), }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, front=None, left=None, back=None, right=None): + all_embeds = [front, left, back, right] + out = [] + pos_embeds = None + for i, e in enumerate(all_embeds): + if e is not None: + if pos_embeds is None: + pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4)) + out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1)) + + embeds = torch.cat(out, dim=1) + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class VOXEL: + def __init__(self, data): + self.data = data + + +class VAEDecodeHunyuan3D: + @classmethod + def INPUT_TYPES(s): + return {"required": {"samples": ("LATENT", ), + "vae": ("VAE", ), + "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), + "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), + }} + RETURN_TYPES = ("VOXEL",) + FUNCTION = "decode" + + CATEGORY = "latent/3d" + + def decode(self, vae, samples, num_chunks, octree_resolution): + voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return (voxels, ) + + +def voxel_to_mesh(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + binary = (voxels > threshold).float() + padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0) + + D, H, W = binary.shape + + neighbors = torch.tensor([ + [0, 0, 1], + [0, 0, -1], + [0, 1, 0], + [0, -1, 0], + [1, 0, 0], + [-1, 0, 0] + ], device=device) + + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + solid_mask = binary.flatten() > 0 + solid_indices = voxel_indices[solid_mask] + + corner_offsets = [ + torch.tensor([ + [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0] + ], device=device), + torch.tensor([ + [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0] + ], device=device) + ] + + all_vertices = [] + all_indices = [] + + vertex_count = 0 + + for face_idx, offset in enumerate(neighbors): + neighbor_indices = solid_indices + offset + + padded_indices = neighbor_indices + 1 + + is_exposed = padded[ + padded_indices[:, 0], + padded_indices[:, 1], + padded_indices[:, 2] + ] == 0 + + if not is_exposed.any(): + continue + + exposed_indices = solid_indices[is_exposed] + + corners = corner_offsets[face_idx].unsqueeze(0) + + face_vertices = exposed_indices.unsqueeze(1) + corners + + all_vertices.append(face_vertices.reshape(-1, 3)) + + num_faces = exposed_indices.shape[0] + face_indices = torch.arange( + vertex_count, + vertex_count + 4 * num_faces, + device=device + ).reshape(-1, 4) + + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 1]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 3], face_indices[:, 2]], dim=1)) + + vertex_count += 4 * num_faces + + vertices = torch.cat(all_vertices, dim=0) + faces = torch.cat(all_indices, dim=0) + + v_min = 0 + v_max = max(voxels.shape) + + vertices = vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + vertices = vertices / scale + + return vertices, faces + + +class MESH: + def __init__(self, vertices, faces): + self.vertices = vertices + self.faces = faces + + +class VoxelToMeshBasic: + @classmethod + def INPUT_TYPES(s): + return {"required": {"voxel": ("VOXEL", ), + "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MESH",) + FUNCTION = "decode" + + CATEGORY = "3d" + + def decode(self, voxel, threshold): + vertices = [] + faces = [] + for x in voxel.data: + v, f = voxel_to_mesh(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return (MESH(torch.stack(vertices), torch.stack(faces)), ) + + +def save_glb(vertices, faces, filepath, metadata=None): + """ + Save PyTorch tensor vertices and faces as a GLB file without external dependencies. + + Parameters: + vertices: torch.Tensor of shape (N, 3) - The vertex coordinates + faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces) + filepath: str - Output filepath (should end with .glb) + """ + + # Convert tensors to numpy arrays + vertices_np = vertices.cpu().numpy().astype(np.float32) + faces_np = faces.cpu().numpy().astype(np.uint32) + + vertices_buffer = vertices_np.tobytes() + indices_buffer = faces_np.tobytes() + + def pad_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b'\x00' * padding_length + + vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) + indices_buffer_padded = pad_to_4_bytes(indices_buffer) + + buffer_data = vertices_buffer_padded + indices_buffer_padded + + vertices_byte_length = len(vertices_buffer) + vertices_byte_offset = 0 + indices_byte_length = len(indices_buffer) + indices_byte_offset = len(vertices_buffer_padded) + + gltf = { + "asset": {"version": "2.0", "generator": "ComfyUI"}, + "buffers": [ + { + "byteLength": len(buffer_data) + } + ], + "bufferViews": [ + { + "buffer": 0, + "byteOffset": vertices_byte_offset, + "byteLength": vertices_byte_length, + "target": 34962 # ARRAY_BUFFER + }, + { + "buffer": 0, + "byteOffset": indices_byte_offset, + "byteLength": indices_byte_length, + "target": 34963 # ELEMENT_ARRAY_BUFFER + } + ], + "accessors": [ + { + "bufferView": 0, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(vertices_np), + "type": "VEC3", + "max": vertices_np.max(axis=0).tolist(), + "min": vertices_np.min(axis=0).tolist() + }, + { + "bufferView": 1, + "byteOffset": 0, + "componentType": 5125, # UNSIGNED_INT + "count": faces_np.size, + "type": "SCALAR" + } + ], + "meshes": [ + { + "primitives": [ + { + "attributes": { + "POSITION": 0 + }, + "indices": 1, + "mode": 4 # TRIANGLES + } + ] + } + ], + "nodes": [ + { + "mesh": 0 + } + ], + "scenes": [ + { + "nodes": [0] + } + ], + "scene": 0 + } + + if metadata is not None: + gltf["asset"]["extras"] = metadata + + # Convert the JSON to bytes + gltf_json = json.dumps(gltf).encode('utf8') + + def pad_json_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b' ' * padding_length + + gltf_json_padded = pad_json_to_4_bytes(gltf_json) + + # Create the GLB header + # Magic glTF + glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) + + # Create JSON chunk header (chunk type 0) + json_chunk_header = struct.pack(' Date: Wed, 19 Mar 2025 19:55:24 -0400 Subject: [PATCH 03/37] Fix orientation of hunyuan 3d model. --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy_extras/nodes_hunyuan3d.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 311c9b41..5eb2c654 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -581,7 +581,7 @@ class ShapeVAE(nn.Module): enable_pbar = kwargs.get("enable_pbar", True) grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) - return grid_logits + return grid_logits.movedim(-2, -1) def encode(self, x): return None diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 6abcde1f..ac2cff3a 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -185,8 +185,8 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None): device=device ).reshape(-1, 4) - all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 1]], dim=1)) - all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 3], face_indices[:, 2]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1)) vertex_count += 4 * num_faces @@ -202,6 +202,7 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None): if scale > 0: vertices = vertices / scale + vertices = torch.fliplr(vertices) return vertices, faces From 3872b43d4ba44ca93eae305298a6474efafa3eb7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 20 Mar 2025 04:52:31 -0400 Subject: [PATCH 04/37] A few fixes for the hunyuan3d models. --- comfy/sd.py | 5 +++-- comfy/supported_models.py | 2 ++ comfy_extras/nodes_hunyuan3d.py | 8 ++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 4160fa89..d096f496 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -419,10 +419,11 @@ class VAE: inner_size = sd["geo_decoder.output_proj.weight"].shape[1] downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size - self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO + self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b5c3194c..be3aede6 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -971,6 +971,8 @@ class Hunyuan3Dv2(supported_models_base.BASE): "shift": 1.0, } + memory_usage_factor = 3.5 + clip_vision_prefix = "conditioner.main_image_encoder.model." vae_key_prefix = ["vae."] diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index ac2cff3a..1ca7c2fe 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -190,8 +190,12 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None): vertex_count += 4 * num_faces - vertices = torch.cat(all_vertices, dim=0) - faces = torch.cat(all_indices, dim=0) + if len(all_vertices) > 0: + vertices = torch.cat(all_vertices, dim=0) + faces = torch.cat(all_indices, dim=0) + else: + vertices = torch.zeros((1, 3)) + faces = torch.zeros((1, 3)) v_min = 0 v_max = max(voxels.shape) From 8b9ce4ed18c24db2b7195b8d33932e516fcb3d85 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 21 Mar 2025 00:17:36 -0400 Subject: [PATCH 05/37] Update frontend to 1.13 (#7331) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 70689bc9..ceec006d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.12.14 +comfyui-frontend-package==1.13.9 torch torchsde torchvision From a4a956dbbdcd9b3072d748f826394dd3223a094b Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 21 Mar 2025 01:47:18 -0400 Subject: [PATCH 06/37] Add backend primitive nodes (#7328) * Add backend primitive nodes * Add control after generate to int primitive --- comfy_extras/nodes_primitive.py | 79 +++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 80 insertions(+) create mode 100644 comfy_extras/nodes_primitive.py diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py new file mode 100644 index 00000000..b770104f --- /dev/null +++ b/comfy_extras/nodes_primitive.py @@ -0,0 +1,79 @@ +# Primitive nodes that are evaluated at backend. +from __future__ import annotations + +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO + + +class String(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.STRING, {})}, + } + + RETURN_TYPES = (IO.STRING,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: str) -> tuple[str]: + return (value,) + + +class Int(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.INT, {"control_after_generate": True})}, + } + + RETURN_TYPES = (IO.INT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: int) -> tuple[int]: + return (value,) + + +class Float(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.FLOAT, {})}, + } + + RETURN_TYPES = (IO.FLOAT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: float) -> tuple[float]: + return (value,) + + +class Boolean(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.BOOLEAN, {})}, + } + + RETURN_TYPES = (IO.BOOLEAN,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: bool) -> tuple[bool]: + return (value,) + + +NODE_CLASS_MAPPINGS = { + "PrimitiveString": String, + "PrimitiveInt": Int, + "PrimitiveFloat": Float, + "PrimitiveBoolean": Boolean, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PrimitiveString": "String", + "PrimitiveInt": "Int", + "PrimitiveFloat": "Float", + "PrimitiveBoolean": "Boolean", +} diff --git a/nodes.py b/nodes.py index f89c328e..a9c931df 100644 --- a/nodes.py +++ b/nodes.py @@ -2265,6 +2265,7 @@ def init_builtin_extra_nodes(): "nodes_lumina2.py", "nodes_wan.py", "nodes_hunyuan3d.py", + "nodes_primitive.py", ] import_failed = [] From 095610717000bffd477a7e72988d1fb2299afacb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 21 Mar 2025 06:32:20 -0400 Subject: [PATCH 07/37] Nodes to convert images to YUV and back. Can be used to convert an image to black and white. --- comfy_extras/nodes_morphology.py | 38 ++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index b1372b8c..075b26c4 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -2,6 +2,7 @@ import torch import comfy.model_management from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat +import kornia.color class Morphology: @@ -40,8 +41,45 @@ class Morphology: img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) return (img_out,) + +class ImageRGBToYUV: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") + RETURN_NAMES = ("Y", "U", "V") + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, image): + out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) + return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) + +class ImageYUVToRGB: + @classmethod + def INPUT_TYPES(s): + return {"required": {"Y": ("IMAGE",), + "U": ("IMAGE",), + "V": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, Y, U, V): + image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) + out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) + return (out,) + NODE_CLASS_MAPPINGS = { "Morphology": Morphology, + "ImageRGBToYUV": ImageRGBToYUV, + "ImageYUVToRGB": ImageYUVToRGB, } NODE_DISPLAY_NAME_MAPPINGS = { From 0cf227469929f74ae5ae887f3f7fa7e490e5e9d0 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 21 Mar 2025 13:50:09 -0400 Subject: [PATCH 08/37] Update frontend to 1.14 (#7343) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ceec006d..c78d3c22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.13.9 +comfyui-frontend-package==1.14.5 torch torchsde torchvision From 83e839a89be1dc6db0923bea45ff9eae43a8ea01 Mon Sep 17 00:00:00 2001 From: thot experiment <94414189+thot-experiment@users.noreply.github.com> Date: Fri, 21 Mar 2025 11:04:15 -0700 Subject: [PATCH 09/37] Native LotusD Implementation (#7125) * draft pass at a native comfy implementation of Lotus-D depth and normal est * fix model_sampling kludges * fix ruff --------- Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> --- comfy/model_base.py | 14 ++++++++++++++ comfy/model_detection.py | 7 ++++++- comfy/supported_models.py | 18 ++++++++++++++++- comfy_extras/nodes_lotus.py | 29 ++++++++++++++++++++++++++++ comfy_extras/nodes_model_advanced.py | 8 +++++++- nodes.py | 1 + 6 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 comfy_extras/nodes_lotus.py diff --git a/comfy/model_base.py b/comfy/model_base.py index f02406ac..2fb4b145 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -140,6 +140,7 @@ class BaseModel(torch.nn.Module): def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) + if c_concat is not None: xc = torch.cat([xc] + [c_concat], dim=1) @@ -601,6 +602,19 @@ class SDXL_instructpix2pix(IP2P, SDXL): else: self.process_ip2p_image_in = lambda image: image #diffusers ip2p +class Lotus(BaseModel): + def extra_conds(self, **kwargs): + out = {} + cross_attn = kwargs.get("cross_attn", None) + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + device = kwargs["device"] + task_emb = torch.tensor([1, 0]).float().to(device) + task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0) + out['y'] = comfy.conds.CONDRegular(task_emb) + return out + + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index f9e96ab7..4217f583 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -682,8 +682,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} + LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, + 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] + supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] for unet_config in supported_models: matches = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index be3aede6..fad00d35 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL): def get_model(self, state_dict, prefix="", device=None): return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) +class LotusD(SD20): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "use_temporal_attention": False, + "adm_in_channels": 4, + "in_channels": 4, + } + + unet_extra_config = { + "num_classes": 'sequential' + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Lotus(self, device=device) + class SD3(supported_models_base.BASE): unet_config = { "in_channels": 16, @@ -997,6 +1013,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_lotus.py b/comfy_extras/nodes_lotus.py new file mode 100644 index 00000000..739dbdd3 --- /dev/null +++ b/comfy_extras/nodes_lotus.py @@ -0,0 +1,29 @@ +import torch +import comfy.model_management as mm + +class LotusConditioning: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + } + + RETURN_TYPES = ("CONDITIONING",) + RETURN_NAMES = ("conditioning",) + FUNCTION = "conditioning" + CATEGORY = "conditioning/lotus" + + def conditioning(self): + device = mm.get_torch_device() + #lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change + #and getting parity with the reference implementation would otherwise require inference and 800mb of tensors + prompt_embeds = torch.tensor([[[-0.3134765625, -0.447509765625, -0.00823974609375, -0.22802734375, 0.1785888671875, -0.2342529296875, -0.2188720703125, -0.0089111328125, -0.31396484375, 0.196533203125, -0.055877685546875, -0.3828125, -0.0965576171875, 0.0073394775390625, -0.284423828125, 0.07470703125, -0.086181640625, -0.211181640625, 0.0599365234375, 0.10693359375, 0.0007929801940917969, -0.78076171875, -0.382568359375, -0.1851806640625, -0.140625, -0.0936279296875, -0.1229248046875, -0.152099609375, -0.203857421875, -0.2349853515625, -0.2437744140625, -0.10858154296875, -0.08990478515625, 0.08892822265625, -0.2391357421875, -0.1611328125, -0.427978515625, -0.1336669921875, -0.27685546875, -0.1781005859375, -0.3857421875, 0.251953125, -0.055999755859375, -0.0712890625, -0.00130462646484375, 0.033477783203125, -0.26416015625, 0.07171630859375, -0.0090789794921875, -0.2025146484375, -0.2763671875, -0.09869384765625, -0.45751953125, -0.23095703125, 0.004528045654296875, -0.369140625, -0.366943359375, -0.205322265625, -0.1505126953125, -0.45166015625, -0.2059326171875, 0.0168609619140625, -0.305419921875, -0.150634765625, 0.02685546875, -0.609375, -0.019012451171875, 0.050445556640625, -0.0084381103515625, -0.31005859375, -0.184326171875, -0.15185546875, 0.06732177734375, 0.150390625, -0.10919189453125, -0.08837890625, -0.50537109375, -0.389892578125, -0.0294342041015625, -0.10491943359375, -0.187255859375, -0.43212890625, -0.328125, -1.060546875, 0.011871337890625, 0.04730224609375, -0.09521484375, -0.07452392578125, -0.29296875, -0.109130859375, -0.250244140625, -0.3828125, -0.171875, -0.03399658203125, -0.15478515625, -0.1861572265625, -0.2398681640625, 0.1053466796875, -0.22314453125, -0.1932373046875, -0.18798828125, -0.430419921875, -0.05364990234375, -0.474609375, -0.261474609375, -0.1077880859375, -0.439208984375, 0.08966064453125, -0.185302734375, -0.338134765625, -0.297119140625, -0.298583984375, -0.175537109375, -0.373291015625, -0.1397705078125, -0.260498046875, -0.383544921875, -0.09979248046875, -0.319580078125, -0.06884765625, -0.4365234375, -0.183837890625, -0.393310546875, -0.002277374267578125, 0.11236572265625, -0.260498046875, -0.2242431640625, -0.19384765625, -0.51123046875, 0.03216552734375, -0.048004150390625, -0.279052734375, -0.2978515625, -0.255615234375, 0.115478515625, -4.08984375, -0.1668701171875, -0.278076171875, -0.5712890625, -0.1385498046875, -0.244384765625, -0.41455078125, -0.244140625, -0.0677490234375, -0.141357421875, -0.11590576171875, -0.1439208984375, -0.0185394287109375, -2.490234375, -0.1549072265625, -0.2305908203125, -0.3828125, -0.1173095703125, -0.08258056640625, -0.1719970703125, -0.325439453125, -0.292724609375, -0.08154296875, -0.412353515625, -0.3115234375, -0.00832366943359375, 0.00489044189453125, -0.2236328125, -0.151123046875, -0.457275390625, -0.135009765625, -0.163330078125, -0.0819091796875, 0.06689453125, 0.0209197998046875, -0.11907958984375, -0.10369873046875, -0.2998046875, -0.478759765625, -0.07940673828125, -0.01517486572265625, -0.3017578125, -0.343994140625, -0.258544921875, -0.44775390625, -0.392822265625, -0.0255584716796875, -0.2998046875, 0.10833740234375, -0.271728515625, -0.36181640625, -0.255859375, -0.2056884765625, -0.055450439453125, 0.060516357421875, -0.45751953125, -0.2322998046875, -0.1737060546875, -0.40576171875, -0.2286376953125, -0.053070068359375, -0.0283660888671875, -0.1898193359375, -4.291534423828125e-05, -0.6591796875, -0.1717529296875, -0.479736328125, -0.1400146484375, -0.40771484375, 0.154296875, 0.003101348876953125, 0.00661468505859375, -0.2073974609375, -0.493408203125, 2.171875, -0.45361328125, -0.283935546875, -0.302001953125, -0.25146484375, -0.207275390625, -0.1524658203125, -0.72998046875, -0.08203125, 0.053192138671875, -0.2685546875, 0.1834716796875, -0.270263671875, -0.091552734375, -0.08319091796875, -0.1297607421875, -0.453857421875, 0.0687255859375, 0.0268096923828125, -0.16552734375, -0.4208984375, -0.1552734375, -0.057373046875, -0.300537109375, -0.04541015625, -0.486083984375, -0.2205810546875, -0.39013671875, 0.007488250732421875, -0.005329132080078125, -0.09759521484375, -0.1448974609375, -0.21923828125, -0.429443359375, -0.40087890625, -0.19384765625, -0.064453125, -0.0306243896484375, -0.045806884765625, -0.056793212890625, 0.119384765625, -0.2073974609375, -0.356201171875, -0.168212890625, -0.291748046875, -0.289794921875, -0.205322265625, -0.419677734375, -0.478271484375, -0.2037353515625, -0.368408203125, -0.186279296875, -0.427734375, -0.1756591796875, 0.07501220703125, -0.2457275390625, -0.03692626953125, 0.003997802734375, -5.7578125, -0.01052093505859375, -0.2305908203125, -0.2252197265625, -0.197509765625, -0.1566162109375, -0.1668701171875, -0.383056640625, -0.05413818359375, 0.12188720703125, -0.369873046875, -0.0184478759765625, -0.150146484375, -0.51123046875, -0.45947265625, -0.1561279296875, 0.060455322265625, 0.043487548828125, -0.1370849609375, -0.069091796875, -0.285888671875, -0.44482421875, -0.2374267578125, -0.2191162109375, -0.434814453125, -0.0360107421875, 0.1298828125, 0.0217742919921875, -0.51220703125, -0.13525390625, -0.09381103515625, -0.276611328125, -0.171875, -0.17138671875, -0.4443359375, -0.2178955078125, -0.269775390625, -0.38623046875, -0.31591796875, -0.42333984375, -0.280029296875, -0.255615234375, -0.17041015625, 0.06268310546875, -0.1878662109375, -0.00677490234375, -0.23583984375, -0.08795166015625, -0.2232666015625, -0.1719970703125, -0.484130859375, -0.328857421875, 0.04669189453125, -0.0419921875, -0.11114501953125, 0.02313232421875, -0.0033130645751953125, -0.6005859375, 0.09051513671875, -0.1884765625, -0.262939453125, -0.375732421875, -0.525390625, -0.1170654296875, -0.3779296875, -0.242919921875, -0.419921875, 0.0665283203125, -0.343017578125, 0.06658935546875, -0.346435546875, -0.1363525390625, -0.2000732421875, -0.3837890625, 0.028167724609375, 0.043853759765625, -0.0171051025390625, -0.477294921875, -0.107421875, -0.129150390625, -0.319580078125, -0.32177734375, -0.4951171875, -0.010589599609375, -0.1778564453125, -0.40234375, -0.0810546875, 0.03314208984375, -0.13720703125, -0.31591796875, -0.048248291015625, -0.274658203125, -0.0689697265625, -0.027130126953125, -0.0953369140625, 0.146728515625, -0.38671875, -0.025390625, -0.42333984375, -0.41748046875, -0.379638671875, -0.1978759765625, -0.533203125, -0.33544921875, 0.0694580078125, -0.322998046875, -0.1876220703125, 0.0094451904296875, 0.1839599609375, -0.254150390625, -0.30078125, -0.09228515625, -0.0885009765625, 0.12371826171875, 0.1500244140625, -0.12152099609375, -0.29833984375, 0.03924560546875, -0.1470947265625, -0.1610107421875, -0.2049560546875, -0.01708984375, -0.2470703125, -0.1522216796875, -0.25830078125, 0.10870361328125, -0.302490234375, -0.2376708984375, -0.360107421875, -0.443359375, -0.0784912109375, -0.63623046875, -0.0980224609375, -0.332275390625, -0.1749267578125, -0.30859375, -0.1968994140625, -0.250244140625, -0.447021484375, -0.18408203125, -0.006908416748046875, -0.2044677734375, -0.2548828125, -0.369140625, -0.11328125, -0.1103515625, -0.27783203125, -0.325439453125, 0.01381683349609375, 0.036773681640625, -0.1458740234375, -0.34619140625, -0.232177734375, -0.0562744140625, -0.4482421875, -0.21875, -0.0855712890625, -0.276123046875, -0.1544189453125, -0.223388671875, -0.259521484375, 0.0865478515625, -0.0038013458251953125, -0.340087890625, -0.076171875, -0.25341796875, -0.0007548332214355469, -0.060455322265625, -0.352294921875, 0.035736083984375, -0.2181396484375, -0.2318115234375, -0.1707763671875, 0.018646240234375, 0.093505859375, -0.197021484375, 0.033477783203125, -0.035247802734375, 0.0440673828125, -0.2056884765625, -0.040924072265625, -0.05865478515625, 0.056884765625, -0.08807373046875, -0.10845947265625, 0.09564208984375, -0.10888671875, -0.332275390625, -0.1119384765625, -0.115478515625, 13.0234375, 0.0030040740966796875, -0.53662109375, -0.1856689453125, -0.068115234375, -0.143798828125, -0.177978515625, -0.32666015625, -0.353515625, -0.1563720703125, -0.3203125, 0.0085906982421875, -0.1043701171875, -0.365478515625, -0.303466796875, -0.34326171875, -0.410888671875, -0.03790283203125, -0.11419677734375, -0.2939453125, 0.074462890625, -0.21826171875, 0.0242767333984375, -0.226318359375, -0.353515625, -0.177734375, -0.169189453125, -0.2423095703125, -0.12115478515625, -0.07843017578125, -0.341064453125, -0.2117919921875, -0.505859375, -0.544921875, -0.3935546875, -0.10772705078125, -0.2054443359375, -0.136474609375, -0.1796875, -0.396240234375, -0.1971435546875, -0.68408203125, -0.032684326171875, -0.03863525390625, -0.0709228515625, -0.1005859375, -0.156005859375, -0.3837890625, -0.319580078125, 0.11102294921875, -0.394287109375, 0.0799560546875, -0.50341796875, -0.1572265625, 0.004131317138671875, -0.12286376953125, -0.2347412109375, -0.29150390625, -0.10321044921875, -0.286376953125, 0.018798828125, -0.152099609375, -0.321044921875, 0.0191650390625, -0.11376953125, -0.54736328125, 0.15869140625, -0.257568359375, -0.2490234375, -0.3115234375, -0.09765625, -0.350830078125, -0.36376953125, -0.0771484375, -0.2298583984375, -0.30615234375, -0.052154541015625, -0.12091064453125, -0.40283203125, -0.1649169921875, 0.0206451416015625, -0.312744140625, -0.10308837890625, -0.50341796875, -0.1754150390625, -0.2003173828125, -0.173583984375, -0.204833984375, -0.1876220703125, -0.12176513671875, -0.06201171875, -0.03485107421875, -0.20068359375, -0.21484375, -0.246337890625, -0.006587982177734375, -0.09674072265625, -0.4658203125, -0.3994140625, -0.2210693359375, -0.09588623046875, -0.126220703125, -0.09222412109375, -0.145751953125, -0.217529296875, -0.289306640625, -0.28271484375, -0.1787109375, -0.169189453125, -0.359375, -0.21826171875, -0.043792724609375, -0.205322265625, -0.2900390625, -0.055419921875, -0.1490478515625, -0.340576171875, -0.045928955078125, -0.30517578125, -0.51123046875, -0.1046142578125, -0.349853515625, -0.10882568359375, -0.16748046875, -0.267333984375, -0.122314453125, -0.0985107421875, -0.3076171875, -0.1766357421875, -0.251708984375, 0.1964111328125, -0.2220458984375, -0.2349853515625, -0.035980224609375, -0.1749267578125, -0.237060546875, -0.480224609375, -0.240234375, -0.09539794921875, -0.2481689453125, -0.389404296875, -0.1748046875, -0.370849609375, -0.010650634765625, -0.147705078125, -0.0035457611083984375, -0.32568359375, -0.29931640625, -0.1395263671875, -0.28173828125, -0.09820556640625, -0.0176239013671875, -0.05926513671875, -0.0755615234375, -0.1746826171875, -0.283203125, -0.1617431640625, -0.4404296875, 0.046234130859375, -0.183837890625, -0.052032470703125, -0.24658203125, -0.11224365234375, -0.100830078125, -0.162841796875, -0.29736328125, -0.396484375, 0.11798095703125, -0.006496429443359375, -0.32568359375, -0.347900390625, -0.04595947265625, -0.09637451171875, -0.344970703125, -0.01166534423828125, -0.346435546875, -0.2861328125, -0.1845703125, -0.276611328125, -0.01312255859375, -0.395263671875, -0.50927734375, -0.1114501953125, -0.1861572265625, -0.2158203125, -0.1812744140625, 0.055419921875, -0.294189453125, 0.06500244140625, -0.1444091796875, -0.06365966796875, -0.18408203125, -0.0091705322265625, -0.1640625, -0.1856689453125, 0.090087890625, 0.024566650390625, -0.0195159912109375, -0.5546875, -0.301025390625, -0.438232421875, -0.072021484375, 0.030517578125, -0.1490478515625, 0.04888916015625, -0.23681640625, -0.1553955078125, -0.018096923828125, -0.229736328125, -0.2919921875, -0.355712890625, -0.285400390625, -0.1756591796875, -0.08355712890625, -0.416259765625, 0.022674560546875, -0.417236328125, 0.410400390625, -0.249755859375, 0.015625, -0.033599853515625, -0.040313720703125, -0.51708984375, -0.0518798828125, -0.08843994140625, -0.2022705078125, -0.3740234375, -0.285888671875, -0.176025390625, -0.292724609375, -0.369140625, -0.08367919921875, -0.356689453125, -0.38623046875, 0.06549072265625, 0.1669921875, -0.2099609375, -0.007434844970703125, 0.12890625, -0.0040740966796875, -0.2174072265625, -0.025115966796875, -0.2364501953125, -0.1695556640625, -0.0469970703125, -0.03924560546875, -0.36181640625, -0.047515869140625, -0.3154296875, -0.275634765625, -0.25634765625, -0.061920166015625, -0.12164306640625, -0.47314453125, -0.10784912109375, -0.74755859375, -0.13232421875, -0.32421875, -0.04998779296875, -0.286376953125, 0.10345458984375, -0.1710205078125, -0.388916015625, 0.12744140625, -0.3359375, -0.302490234375, -0.238525390625, -0.1455078125, -0.15869140625, -0.2427978515625, -0.0355224609375, -0.11944580078125, -0.31298828125, 0.11456298828125, -0.287841796875, -0.5439453125, -0.3076171875, -0.08642578125, -0.2408447265625, -0.283447265625, -0.428466796875, -0.085693359375, -0.1683349609375, 0.255126953125, 0.07635498046875, -0.38623046875, -0.2025146484375, -0.1331787109375, -0.10821533203125, -0.49951171875, 0.09130859375, -0.19677734375, -0.01904296875, -0.151123046875, -0.344482421875, -0.316650390625, -0.03900146484375, 0.1397705078125, 0.1334228515625, -0.037200927734375, -0.01861572265625, -0.1351318359375, -0.07037353515625, -0.380615234375, -0.34033203125, -0.06903076171875, 0.219970703125, 0.0132598876953125, -0.15869140625, -0.6376953125, 0.158935546875, -0.5283203125, -0.2320556640625, -0.185791015625, -0.2132568359375, -0.436767578125, -0.430908203125, -0.1763916015625, -0.0007672309875488281, -0.424072265625, -0.06719970703125, -0.347900390625, -0.14453125, -0.3056640625, -0.36474609375, -0.35986328125, -0.46240234375, -0.446044921875, -0.1905517578125, -0.1114501953125, -0.42919921875, -0.0643310546875, -0.3662109375, -0.4296875, -0.10968017578125, -0.2998046875, -0.1756591796875, -0.4052734375, -0.0841064453125, -0.252197265625, -0.047393798828125, 0.00434112548828125, -0.10040283203125, -0.271484375, -0.185302734375, -0.1910400390625, 0.10260009765625, 0.01393890380859375, -0.03350830078125, -0.33935546875, -0.329345703125, 0.0574951171875, -0.18896484375, -0.17724609375, -0.42919921875, -0.26708984375, -0.4189453125, -0.149169921875, -0.265625, -0.198974609375, -0.1722412109375, 0.1563720703125, -0.20947265625, -0.267822265625, -0.06353759765625, -0.365478515625, -0.340087890625, -0.3095703125, -0.320068359375, -0.0880126953125, -0.353759765625, -0.0005812644958496094, -0.1617431640625, -0.1866455078125, -0.201416015625, -0.181396484375, -0.2349853515625, -0.384765625, -0.5244140625, 0.01227569580078125, -0.21337890625, -0.30810546875, -0.17578125, -0.3037109375, -0.52978515625, -0.1561279296875, -0.296142578125, 0.057342529296875, -0.369384765625, -0.107666015625, -0.338623046875, -0.2060546875, -0.0213775634765625, -0.394775390625, -0.219482421875, -0.125732421875, -0.03997802734375, -0.42431640625, -0.134521484375, -0.2418212890625, -0.10504150390625, 0.1552734375, 0.1126708984375, -0.1427001953125, -0.133544921875, -0.111083984375, -0.375732421875, -0.2783203125, -0.036834716796875, -0.11053466796875, 0.2471923828125, -0.2529296875, -0.56494140625, -0.374755859375, -0.326416015625, 0.2137451171875, -0.09454345703125, -0.337158203125, -0.3359375, -0.34375, -0.0999755859375, -0.388671875, 0.0103302001953125, 0.14990234375, -0.2041015625, -0.39501953125, -0.39013671875, -0.1258544921875, 0.1453857421875, -0.250732421875, -0.06732177734375, -0.10638427734375, -0.032379150390625, -0.35888671875, -0.098876953125, -0.172607421875, 0.05126953125, -0.1956787109375, -0.183837890625, -0.37060546875, 0.1556396484375, -0.34375, -0.28662109375, -0.06982421875, -0.302490234375, -0.281005859375, -0.1640625, -0.5302734375, -0.1368408203125, -0.1268310546875, -0.35302734375, -0.1473388671875, -0.45556640625, -0.35986328125, -0.273681640625, -0.2249755859375, -0.1893310546875, 0.09356689453125, -0.248291015625, -0.197998046875, -0.3525390625, -0.30126953125, -0.228271484375, -0.2421875, -0.0906982421875, 0.227783203125, -0.296875, -0.009796142578125, -0.2939453125, -0.1021728515625, -0.215576171875, -0.267822265625, -0.052642822265625, 0.203369140625, -0.1417236328125, 0.18505859375, 0.12347412109375, -0.0972900390625, -0.54052734375, -0.430419921875, -0.0906982421875, -0.5419921875, -0.22900390625, -0.0625, -0.12152099609375, -0.495849609375, -0.206787109375, -0.025848388671875, 0.039031982421875, -0.453857421875, -0.318359375, -0.426025390625, -0.3701171875, -0.2169189453125, 0.0845947265625, -0.045654296875, 0.11090087890625, 0.0012454986572265625, 0.2066650390625, -0.046356201171875, -0.2337646484375, -0.295654296875, 0.057891845703125, -0.1639404296875, -0.0535888671875, -0.2607421875, -0.1488037109375, -0.16015625, -0.54345703125, -0.2305908203125, -0.55029296875, -0.178955078125, -0.222412109375, -0.0711669921875, -0.12298583984375, -0.119140625, -0.253662109375, -0.33984375, -0.11322021484375, -0.10723876953125, -0.205078125, -0.360595703125, 0.085205078125, -0.252197265625, -0.365966796875, -0.26953125, 0.2000732421875, -0.50634765625, 0.05706787109375, -0.3115234375, 0.0242919921875, -0.1689453125, -0.2401123046875, -0.3759765625, -0.2125244140625, 0.076416015625, -0.489013671875, -0.11749267578125, -0.55908203125, -0.313232421875, -0.572265625, -0.1387939453125, -0.037078857421875, -0.385498046875, 0.0323486328125, -0.39404296875, -0.05072021484375, -0.10430908203125, -0.10919189453125, -0.28759765625, -0.37451171875, -0.016937255859375, -0.2200927734375, -0.296875, -0.0286712646484375, -0.213134765625, 0.052001953125, -0.052337646484375, -0.253662109375, 0.07269287109375, -0.2498779296875, -0.150146484375, -0.09930419921875, -0.343505859375, 0.254150390625, -0.032440185546875, -0.296142578125], [1.4111328125, 0.00757598876953125, -0.428955078125, 0.089599609375, 0.0227813720703125, -0.0350341796875, -1.0986328125, 0.194091796875, 2.115234375, -0.75439453125, 0.269287109375, -0.73486328125, -1.1025390625, -0.050262451171875, -0.5830078125, 0.0268707275390625, -0.603515625, -0.6025390625, -1.1689453125, 0.25048828125, -0.4189453125, -0.5517578125, -0.30322265625, 0.7724609375, 0.931640625, -0.1422119140625, 2.27734375, -0.56591796875, 1.013671875, -0.9638671875, -0.66796875, -0.8125, 1.3740234375, -1.060546875, -1.029296875, -1.6796875, 0.62890625, 0.49365234375, 0.671875, 0.99755859375, -1.0185546875, -0.047027587890625, -0.374267578125, 0.2354736328125, 1.4970703125, -1.5673828125, 0.448974609375, 0.2078857421875, -1.060546875, -0.171875, -0.6201171875, -0.1607666015625, 0.7548828125, -0.58935546875, -0.2052001953125, 0.060791015625, 0.200439453125, 3.154296875, -3.87890625, 2.03515625, 1.126953125, 0.1640625, -1.8447265625, 0.002620697021484375, 0.7998046875, -0.337158203125, 0.47216796875, -0.5849609375, 0.9970703125, 0.3935546875, 1.22265625, -1.5048828125, -0.65673828125, 1.1474609375, -1.73046875, -1.8701171875, 1.529296875, -0.6787109375, -1.4453125, 1.556640625, -0.327392578125, 2.986328125, -0.146240234375, -2.83984375, 0.303466796875, -0.71728515625, -0.09698486328125, -0.2423095703125, 0.6767578125, -2.197265625, -0.86279296875, -0.53857421875, -1.2236328125, 1.669921875, -1.1689453125, -0.291259765625, -0.54736328125, -0.036346435546875, 1.041015625, -1.7265625, -0.6064453125, -0.1634521484375, 0.2381591796875, 0.65087890625, -1.169921875, 1.9208984375, 0.5634765625, 0.37841796875, 0.798828125, -1.021484375, -0.4091796875, 2.275390625, -0.302734375, -1.7783203125, 1.0458984375, 1.478515625, 0.708984375, -1.541015625, -0.0006041526794433594, 1.1884765625, 2.041015625, 0.560546875, -0.1131591796875, 1.0341796875, 0.06121826171875, 2.6796875, -0.53369140625, -1.2490234375, -0.7333984375, -1.017578125, -1.0078125, 1.3212890625, -0.47607421875, -1.4189453125, 0.54052734375, -0.796875, -0.73095703125, -1.412109375, -0.94873046875, -2.2734375, -1.1220703125, -1.3837890625, -0.5087890625, -1.0380859375, -0.93603515625, -0.58349609375, -1.0703125, -1.10546875, -2.60546875, 0.062225341796875, 0.38232421875, -0.411376953125, -0.369140625, -0.9833984375, -0.7294921875, -0.181396484375, -0.47216796875, -0.56884765625, -0.11041259765625, -2.673828125, 0.27783203125, -0.857421875, 0.9296875, 1.9580078125, 0.1385498046875, -1.91796875, -1.529296875, 0.53857421875, 0.509765625, -0.90380859375, -0.0947265625, -2.083984375, 0.9228515625, -0.28564453125, -0.80859375, -0.093505859375, -0.6015625, -1.255859375, 0.6533203125, 0.327880859375, -0.07598876953125, -0.22705078125, -0.30078125, -0.5185546875, -1.6044921875, 1.5927734375, 1.416015625, -0.91796875, -0.276611328125, -0.75830078125, -1.1689453125, -1.7421875, 1.0546875, -0.26513671875, -0.03314208984375, 0.278076171875, -1.337890625, 0.055023193359375, 0.10546875, -1.064453125, 1.048828125, -1.4052734375, -1.1240234375, -0.51416015625, -1.05859375, -1.7265625, -1.1328125, 0.43310546875, -2.576171875, -2.140625, -0.79345703125, 0.50146484375, 1.96484375, 0.98583984375, 0.337646484375, -0.77978515625, 0.85498046875, -0.65185546875, -0.484375, 2.708984375, 0.55810546875, -0.147216796875, -0.5537109375, -0.75439453125, -1.736328125, 1.1259765625, -1.095703125, -0.2587890625, 2.978515625, 0.335205078125, 0.357666015625, -0.09356689453125, 0.295654296875, -0.23779296875, 1.5751953125, 0.10400390625, 1.7001953125, -0.72900390625, -1.466796875, -0.2012939453125, 0.634765625, -0.1556396484375, -2.01171875, 0.32666015625, 0.047454833984375, -0.1671142578125, -0.78369140625, -0.994140625, 0.7802734375, -0.1429443359375, -0.115234375, 0.53271484375, -0.96142578125, -0.064208984375, 1.396484375, 1.654296875, -1.6015625, -0.77392578125, 0.276123046875, -0.42236328125, 0.8642578125, 0.533203125, 0.397216796875, -1.21484375, 0.392578125, -0.501953125, -0.231689453125, 1.474609375, 1.6669921875, 1.8662109375, -1.2998046875, 0.223876953125, -0.51318359375, -0.437744140625, -1.16796875, -0.7724609375, 1.6826171875, 0.62255859375, 2.189453125, -0.599609375, -0.65576171875, -1.1005859375, -0.45263671875, -0.292236328125, 2.58203125, -1.3779296875, 0.23486328125, -1.708984375, -1.4111328125, -0.5078125, -0.8525390625, -0.90771484375, 0.861328125, -2.22265625, -1.380859375, 0.7275390625, 0.85595703125, -0.77978515625, 2.044921875, -0.430908203125, 0.78857421875, -1.21484375, -0.09130859375, 0.5146484375, -1.92578125, -0.1396484375, 0.289306640625, 0.60498046875, 0.93896484375, -0.09295654296875, -0.45751953125, -0.986328125, -0.66259765625, 1.48046875, 0.274169921875, -0.267333984375, -1.3017578125, -1.3623046875, -1.982421875, -0.86083984375, -0.41259765625, -0.2939453125, -1.91015625, 1.6826171875, 0.437255859375, 1.0029296875, 0.376220703125, -0.010467529296875, -0.82861328125, -0.513671875, -3.134765625, 1.0205078125, -1.26171875, -1.009765625, 1.0869140625, -0.95703125, 0.0103759765625, 1.642578125, 0.78564453125, 1.029296875, 0.496826171875, 1.2880859375, 0.5234375, 0.05322265625, -0.206787109375, -0.79443359375, -1.1669921875, 0.049530029296875, -0.27978515625, 0.0237884521484375, -0.74169921875, -1.068359375, 0.86083984375, 1.1787109375, 0.91064453125, -0.453857421875, -1.822265625, -0.9228515625, -0.50048828125, 0.359130859375, 0.802734375, -1.3564453125, -0.322509765625, -1.1123046875, -1.0390625, -0.52685546875, -1.291015625, -0.343017578125, -1.2109375, -0.19091796875, 2.146484375, -0.04315185546875, -0.3701171875, -2.044921875, -0.429931640625, -0.56103515625, -0.166015625, -0.4658203125, -2.29296875, -1.078125, -1.0927734375, -0.1033935546875, -0.56103515625, -0.05743408203125, -1.986328125, -0.513671875, 0.70361328125, -2.484375, -1.3037109375, -1.6650390625, 0.4814453125, -0.84912109375, -2.697265625, -0.197998046875, 0.0869140625, -0.172607421875, -1.326171875, -1.197265625, 1.23828125, -0.38720703125, -0.075927734375, 0.02569580078125, -1.2119140625, 0.09027099609375, -2.12890625, -1.640625, -0.1524658203125, 0.2373046875, 1.37109375, 2.248046875, 1.4619140625, 0.3134765625, 0.50244140625, -0.1383056640625, -1.2705078125, 0.7353515625, 0.65771484375, -0.431396484375, -1.341796875, 0.10089111328125, 0.208984375, -0.0099945068359375, 0.83203125, 1.314453125, -0.422607421875, -1.58984375, -0.6044921875, 0.23681640625, -1.60546875, -0.61083984375, -1.5615234375, 1.62890625, -0.6728515625, -0.68212890625, -0.5224609375, -0.9150390625, -0.468994140625, 0.268310546875, 0.287353515625, -0.025543212890625, 0.443603515625, 1.62109375, -1.08984375, -0.5556640625, 1.03515625, -0.31298828125, -0.041778564453125, 0.260986328125, 0.34716796875, -2.326171875, 0.228271484375, -0.85107421875, -2.255859375, 0.3486328125, -0.25830078125, -0.3671875, -0.796875, -1.115234375, 1.8369140625, -0.19775390625, -1.236328125, -0.0447998046875, 0.69921875, 1.37890625, 1.11328125, 0.0928955078125, 0.6318359375, -0.62353515625, 0.55859375, -0.286865234375, 1.5361328125, -0.391357421875, -0.052215576171875, -1.12890625, 0.55517578125, -0.28515625, -0.3603515625, 0.68896484375, 0.67626953125, 0.003070831298828125, 1.2236328125, 0.1597900390625, -1.3076171875, 0.99951171875, -2.5078125, -1.2119140625, 0.1749267578125, -1.1865234375, -1.234375, -0.1180419921875, -1.751953125, 0.033050537109375, 0.234130859375, -3.107421875, -1.0380859375, 0.61181640625, -0.87548828125, 0.3154296875, -1.103515625, 0.261474609375, -1.130859375, -0.7470703125, -0.43408203125, 1.3828125, -0.41259765625, -1.7587890625, 0.765625, 0.004852294921875, 0.135498046875, -0.76953125, -0.1314697265625, 0.400390625, 1.43359375, 0.07135009765625, 0.0645751953125, -0.5869140625, -0.5810546875, -0.2900390625, -1.3037109375, 0.1287841796875, -0.27490234375, 0.59228515625, 2.333984375, -0.54541015625, -0.556640625, 0.447265625, -0.806640625, 0.09149169921875, -0.70654296875, -0.357177734375, -1.099609375, -0.5576171875, -0.44189453125, 0.400390625, -0.666015625, -1.4619140625, 0.728515625, -1.5986328125, 0.153076171875, -0.126708984375, -2.83984375, -1.84375, -0.2469482421875, 0.677734375, 0.43701171875, 3.298828125, 1.1591796875, -0.7158203125, -0.8251953125, 0.451171875, -2.376953125, -0.58642578125, -0.86767578125, 0.0789794921875, 0.1351318359375, -0.325439453125, 0.484375, 1.166015625, -0.1610107421875, -0.15234375, -0.54638671875, -0.806640625, 0.285400390625, 0.1661376953125, -0.50146484375, -1.0478515625, 1.5751953125, 0.0313720703125, 0.2396240234375, -0.6572265625, -0.1258544921875, -1.060546875, 1.3076171875, -0.301513671875, -1.2412109375, 0.6376953125, -1.5693359375, 0.354248046875, 0.2427978515625, -0.392333984375, 0.61962890625, -0.58837890625, -1.71484375, -0.2098388671875, -0.828125, 0.330810546875, 0.16357421875, -0.2259521484375, 0.0972900390625, -0.451416015625, 1.79296875, -1.673828125, -1.58203125, -2.099609375, -0.487548828125, -0.87060546875, 0.62646484375, -1.470703125, -0.1558837890625, 0.4609375, 1.3369140625, 0.2322998046875, 0.1632080078125, 0.65966796875, 1.0810546875, 0.1041259765625, 0.63232421875, -0.32421875, -1.04296875, -1.046875, -1.3720703125, -0.8486328125, 0.1290283203125, 0.137939453125, 0.1549072265625, -1.0908203125, 0.0167694091796875, -0.31689453125, 1.390625, 0.07269287109375, 1.0390625, 1.1162109375, -0.455810546875, -0.06689453125, -0.053741455078125, 0.5048828125, -0.8408203125, -1.19921875, 0.87841796875, 0.7421875, 0.2030029296875, 0.109619140625, -0.59912109375, -1.337890625, -0.74169921875, -0.64453125, -1.326171875, 0.21044921875, -1.3583984375, -1.685546875, -0.472900390625, -0.270263671875, 0.99365234375, -0.96240234375, 1.1279296875, -0.45947265625, -0.45654296875, -0.99169921875, -3.515625, -1.9853515625, 0.73681640625, 0.92333984375, -0.56201171875, -1.4453125, -2.078125, 0.94189453125, -1.333984375, 0.0982666015625, 0.60693359375, 0.367431640625, 3.015625, -1.1357421875, -1.5634765625, 0.90234375, -0.1783447265625, 0.1802978515625, -0.317138671875, -0.513671875, 1.2353515625, -0.033203125, 1.4482421875, 1.0087890625, 0.9248046875, 0.10418701171875, 0.7626953125, -1.3798828125, 0.276123046875, 0.55224609375, 1.1005859375, -0.62158203125, -0.806640625, 0.65087890625, 0.270263671875, -0.339111328125, -0.9384765625, -0.09381103515625, -0.7216796875, 1.37890625, -0.398193359375, -0.3095703125, -1.4912109375, 0.96630859375, 0.43798828125, 0.62255859375, 0.0213470458984375, 0.235595703125, -1.2958984375, 0.0157318115234375, -0.810546875, 1.9736328125, -0.2462158203125, 0.720703125, 0.822265625, -0.755859375, -0.658203125, 0.344482421875, -2.892578125, -0.282470703125, 1.2529296875, -0.294189453125, 0.6748046875, -0.80859375, 0.9287109375, 1.27734375, -1.71875, -0.166015625, 0.47412109375, -0.41259765625, -1.3681640625, -0.978515625, -0.77978515625, -1.044921875, -0.90380859375, -0.08184814453125, -0.86181640625, -0.10772705078125, -0.299560546875, -0.4306640625, -0.47119140625, 0.95703125, 1.107421875, 0.91796875, 0.76025390625, 0.7392578125, -0.09161376953125, -0.7392578125, 0.9716796875, -0.395751953125, -0.75390625, -0.164306640625, -0.087646484375, 0.028564453125, -0.91943359375, -0.66796875, 2.486328125, 0.427734375, 0.626953125, 0.474853515625, 0.0926513671875, 0.830078125, -0.6923828125, 0.7841796875, -0.89208984375, -2.482421875, 0.034912109375, -1.3447265625, -0.475341796875, -0.286376953125, -0.732421875, 0.190673828125, -0.491455078125, -3.091796875, -1.2783203125, -0.66015625, -0.1507568359375, 0.042236328125, -1.025390625, 0.12744140625, -1.984375, -0.393798828125, -1.25, -1.140625, 1.77734375, 0.2457275390625, -0.8017578125, 0.7763671875, -0.387939453125, -0.3662109375, 1.1572265625, 0.123291015625, -0.07135009765625, 1.412109375, -0.685546875, -3.078125, 0.031524658203125, -0.70458984375, 0.78759765625, 0.433837890625, -1.861328125, -1.33203125, 2.119140625, -1.3544921875, -0.6591796875, -1.4970703125, 0.40625, -2.078125, -1.30859375, 0.050262451171875, -0.60107421875, 1.0078125, 0.05657958984375, -0.96826171875, 0.0264892578125, 0.159912109375, 0.84033203125, -1.1494140625, -0.0433349609375, -0.2034912109375, 1.09765625, -1.142578125, -0.283203125, -0.427978515625, 1.0927734375, -0.67529296875, -0.61572265625, 2.517578125, 0.84130859375, 1.8662109375, 0.1748046875, -0.407958984375, -0.029449462890625, -0.27587890625, -0.958984375, -0.10028076171875, 1.248046875, -0.0792236328125, -0.45556640625, 0.7685546875, 1.5556640625, -1.8759765625, -0.131591796875, -1.3583984375, 0.7890625, 0.80810546875, -1.0322265625, -0.53076171875, -0.1484375, -1.7841796875, -1.2470703125, 0.17138671875, -0.04864501953125, -0.80322265625, -0.0933837890625, 0.984375, 0.7001953125, 0.5380859375, 0.2022705078125, -1.1865234375, 0.5439453125, 1.1318359375, 0.79931640625, 0.32666015625, -1.26171875, 0.457763671875, 1.1591796875, -0.34423828125, 0.65771484375, 0.216552734375, 1.19140625, -0.2744140625, -0.020416259765625, -0.86376953125, 0.93017578125, 1.0556640625, 0.69873046875, -0.15087890625, -0.33056640625, 0.8505859375, 0.06890869140625, 0.359375, -0.262939453125, 0.12493896484375, 0.017059326171875, -0.98974609375, 0.5107421875, 0.2408447265625, 0.615234375, -0.62890625, 0.86962890625, -0.07427978515625, 0.85595703125, 0.300537109375, -1.072265625, -1.6064453125, -0.353515625, -0.484130859375, -0.6044921875, -0.455810546875, 0.95849609375, 1.3671875, 0.544921875, 0.560546875, 0.34521484375, -0.6513671875, -0.410400390625, -0.2021484375, -0.1656494140625, 0.073486328125, 0.84716796875, -1.7998046875, -1.0126953125, -0.1324462890625, 0.95849609375, -0.669921875, -0.79052734375, -2.193359375, -0.42529296875, -1.7275390625, -1.04296875, 0.716796875, -0.4423828125, -1.193359375, 0.61572265625, -1.5224609375, 0.62890625, -0.705078125, 0.677734375, -0.213134765625, -1.6748046875, -1.087890625, -0.65185546875, -1.1337890625, 2.314453125, -0.352783203125, -0.27001953125, -2.01953125, -1.2685546875, 0.308837890625, -0.280517578125, -1.3798828125, -1.595703125, 0.642578125, 1.693359375, -0.82470703125, -1.255859375, 0.57373046875, 1.5859375, 1.068359375, -0.876953125, 0.370849609375, 1.220703125, 0.59765625, 0.007602691650390625, 0.09326171875, -0.9521484375, -0.024932861328125, -0.94775390625, -0.299560546875, -0.002536773681640625, 1.41796875, -0.06903076171875, -1.5927734375, 0.353515625, 3.63671875, -0.765625, -1.1142578125, 0.4287109375, -0.86865234375, -0.9267578125, -0.21826171875, -1.10546875, 0.29296875, -0.225830078125, 0.5400390625, -0.45556640625, -0.68701171875, -0.79150390625, -1.0810546875, 0.25439453125, -1.2998046875, -0.494140625, -0.1510009765625, 1.5615234375, -0.4248046875, -0.486572265625, 0.45458984375, 0.047637939453125, -0.11639404296875, 0.057403564453125, 0.130126953125, -0.10125732421875, -0.56201171875, 1.4765625, -1.7451171875, 1.34765625, -0.45703125, 0.873046875, -0.056121826171875, -0.8876953125, -0.986328125, 1.5654296875, 0.49853515625, 0.55859375, -0.2198486328125, 0.62548828125, 0.2734375, -0.63671875, -0.41259765625, -1.2705078125, 0.0665283203125, 1.3369140625, 0.90283203125, -0.77685546875, -1.5, -1.8525390625, -1.314453125, -0.86767578125, -0.331787109375, 0.1590576171875, 0.94775390625, -0.1771240234375, 1.638671875, -2.17578125, 0.58740234375, 0.424560546875, -0.3466796875, 0.642578125, 0.473388671875, 0.96435546875, 1.38671875, -0.91357421875, 1.0361328125, -0.67333984375, 1.5009765625]]]).to(device) + + cond = [[prompt_embeds, {}]] + + return (cond,) + +NODE_CLASS_MAPPINGS = { + "LotusConditioning" : LotusConditioning, +} diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index ceac5654..2b805c1e 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -24,6 +24,10 @@ class X0(comfy.model_sampling.EPS): def calculate_denoised(self, sigma, model_output, model_input): return model_output +class Lotus(X0): + def calculate_input(self, sigma, noise): + return noise + class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -56,7 +60,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "lotus"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -78,6 +82,8 @@ class ModelSamplingDiscrete: sampling_base = ModelSamplingDiscreteDistilled elif sampling == "x0": sampling_type = X0 + elif sampling == "lotus": + sampling_type = Lotus class ModelSamplingAdvanced(sampling_base, sampling_type): pass diff --git a/nodes.py b/nodes.py index a9c931df..27ef743b 100644 --- a/nodes.py +++ b/nodes.py @@ -2264,6 +2264,7 @@ def init_builtin_extra_nodes(): "nodes_video.py", "nodes_lumina2.py", "nodes_wan.py", + "nodes_lotus.py", "nodes_hunyuan3d.py", "nodes_primitive.py", ] From d9fa9d307ff49d3bad50b623306118d483a387fd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 21 Mar 2025 14:19:37 -0400 Subject: [PATCH 10/37] Automatically set the right sampling type for lotus. --- comfy/model_base.py | 5 ++++- comfy/model_sampling.py | 9 +++++++++ comfy_extras/nodes_model_advanced.py | 16 ++++------------ 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2fb4b145..eec70d5d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -59,6 +59,7 @@ class ModelType(Enum): FLOW = 6 V_PREDICTION_CONTINUOUS = 7 FLUX = 8 + IMG_TO_IMG = 9 from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV @@ -89,6 +90,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux + elif model_type == ModelType.IMG_TO_IMG: + c = comfy.model_sampling.IMG_TO_IMG class ModelSampling(s, c): pass @@ -613,7 +616,7 @@ class Lotus(BaseModel): out['y'] = comfy.conds.CONDRegular(task_emb) return out - def __init__(self, model_config, model_type=ModelType.EPS, device=None): + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None): super().__init__(model_config, model_type, device=device) class StableCascade_C(BaseModel): diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index ff27b09a..b79af1e9 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -69,6 +69,15 @@ class CONST: sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) return latent / (1.0 - sigma) +class X0(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + +class IMG_TO_IMG(X0): + def calculate_input(self, sigma, noise): + return noise + + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): super().__init__() diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 2b805c1e..71a652ff 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -20,14 +20,6 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class X0(comfy.model_sampling.EPS): - def calculate_denoised(self, sigma, model_output, model_input): - return model_output - -class Lotus(X0): - def calculate_input(self, sigma, noise): - return noise - class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -60,7 +52,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0", "lotus"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -81,9 +73,9 @@ class ModelSamplingDiscrete: sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled elif sampling == "x0": - sampling_type = X0 - elif sampling == "lotus": - sampling_type = Lotus + sampling_type = comfy.model_sampling.X0 + elif sampling == "img_to_img": + sampling_type = comfy.model_sampling.IMG_TO_IMG class ModelSamplingAdvanced(sampling_base, sampling_type): pass From 2206246055af7996ee8c6cb79346767d90da8372 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 21 Mar 2025 16:24:13 -0400 Subject: [PATCH 11/37] support output normal and lineart once (#7290) --- comfy_extras/nodes_load_3d.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 8b43cf21..db30030f 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -21,8 +21,8 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING") - RETURN_NAMES = ("image", "mask", "mesh_path") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart") FUNCTION = "process" EXPERIMENTAL = True @@ -32,12 +32,16 @@ class Load3D(): def process(self, model_file, image, **kwargs): image_path = folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) + normal_path = folder_paths.get_annotated_filepath(image['normal']) + lineart_path = folder_paths.get_annotated_filepath(image['lineart']) load_image_node = nodes.LoadImage() output_image, ignore_mask = load_image_node.load_image(image=image_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path) + normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) + lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - return output_image, output_mask, model_file, + return output_image, output_mask, model_file, normal_image, lineart_image class Load3DAnimation(): @classmethod @@ -55,8 +59,8 @@ class Load3DAnimation(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING") - RETURN_NAMES = ("image", "mask", "mesh_path") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal") FUNCTION = "process" EXPERIMENTAL = True @@ -66,12 +70,14 @@ class Load3DAnimation(): def process(self, model_file, image, **kwargs): image_path = folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) + normal_path = folder_paths.get_annotated_filepath(image['normal']) load_image_node = nodes.LoadImage() output_image, ignore_mask = load_image_node.load_image(image=image_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path) + normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - return output_image, output_mask, model_file, + return output_image, output_mask, model_file, normal_image class Preview3D(): @classmethod From ce9b084279110f78ca2faf53fb0ef05ac4aaba48 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 21 Mar 2025 19:08:25 -0400 Subject: [PATCH 12/37] [nit] Format error strings (#7345) --- app/frontend_management.py | 53 +++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index b4ba994d..c56ea86e 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -22,13 +22,21 @@ import app.logger # The path to the requirements.txt file req_path = Path(__file__).parents[1] / "requirements.txt" + def frontend_install_warning_message(): """The warning message to display when the frontend version is not up to date.""" extra = "" if sys.flags.no_user_site: extra = "-s " - return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem" + return f""" +Please install the updated requirements.txt file by running: +{sys.executable} {extra}-m pip install -r {req_path} + +This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. + +If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem +""".strip() def check_frontend_version(): @@ -43,7 +51,17 @@ def check_frontend_version(): with open(req_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: - app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message())) + app.logger.log_startup_warning( + f""" +________________________________________________________________________ +WARNING WARNING WARNING WARNING WARNING + +Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}. + +{frontend_install_warning_message()} +________________________________________________________________________ +""".strip() + ) else: logging.info("ComfyUI frontend version: {}".format(frontend_version_str)) except Exception as e: @@ -150,9 +168,20 @@ class FrontendManager: def default_frontend_path(cls) -> str: try: import comfyui_frontend_package + return str(importlib.resources.files(comfyui_frontend_package) / "static") except ImportError: - logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n") + logging.error( + f""" +********** ERROR *********** + +comfyui-frontend-package is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) sys.exit(-1) @classmethod @@ -175,7 +204,9 @@ class FrontendManager: return match_result.group(1), match_result.group(2), match_result.group(3) @classmethod - def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: + def init_frontend_unsafe( + cls, version_string: str, provider: Optional[FrontEndProvider] = None + ) -> str: """ Initializes the frontend for the specified version. @@ -197,12 +228,20 @@ class FrontendManager: repo_owner, repo_name, version = cls.parse_version_string(version_string) if version.startswith("v"): - expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v")) + expected_path = str( + Path(cls.CUSTOM_FRONTENDS_ROOT) + / f"{repo_owner}_{repo_name}" + / version.lstrip("v") + ) if os.path.exists(expected_path): - logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}") + logging.info( + f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}" + ) return expected_path - logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...") + logging.info( + f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..." + ) provider = provider or FrontEndProvider(repo_owner, repo_name) release = provider.get_release(version) From 75c1c757d90ca891eff823893248ef8b51d31d01 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 21 Mar 2025 20:09:54 -0400 Subject: [PATCH 13/37] ComfyUI version v0.3.27 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b5e6fbea..70562252 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.26" +__version__ = "0.3.27" diff --git a/pyproject.toml b/pyproject.toml index f13fed8d..db9e776c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.26" +version = "0.3.27" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From e471c726e57b3854e0dd47efe0e7c53a28703dbb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 22 Mar 2025 15:45:56 -0400 Subject: [PATCH 14/37] Fallback to pytorch attention if sage attention fails. --- comfy/ldm/modules/attention.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 7908d131..ede50646 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape - tensor_layout="HND" + tensor_layout = "HND" else: b, _, dim_head = q.shape dim_head //= heads @@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.view(b, -1, heads, dim_head), (q, k, v), ) - tensor_layout="NHD" + tensor_layout = "NHD" if mask is not None: # add a batch dimension if there isn't already one @@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= if mask.ndim == 3: mask = mask.unsqueeze(1) - out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + try: + out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + except Exception as e: + logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + if tensor_layout == "NHD": + q, k, v = map( + lambda t: t.transpose(1, 2), + (q, k, v), + ) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + if tensor_layout == "HND": if not skip_output_reshape: out = ( From 581a9991ff641ef330a2977d5b92e682c9c3df95 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Mar 2025 08:06:36 -0400 Subject: [PATCH 15/37] Add model merging node for WAN 2.1 --- .../nodes_model_merging_model_specific.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 3e37f70d..dc341194 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks): return {"required": arg_dict} +class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb." + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["patch_embedding."] = argument + arg_dict["time_embedding."] = argument + arg_dict["time_projection."] = argument + arg_dict["text_embedding."] = argument + arg_dict["img_emb."] = argument + + for i in range(40): + arg_dict["blocks.{}.".format(i)] = argument + + arg_dict["head."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks @@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeLTXV": ModelMergeLTXV, "ModelMergeCosmos7B": ModelMergeCosmos7B, "ModelMergeCosmos14B": ModelMergeCosmos14B, + "ModelMergeWAN2_1": ModelMergeWAN2_1, } From eade1551bbd8678a7883d7061de73264cc279abf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Mar 2025 07:14:32 -0400 Subject: [PATCH 16/37] Add Hunyuan3D to readme. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index a807ea9d..a99aca0e 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) +- 3D Models + - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2) - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. From 8edc1f44c1312d58afb6b0d817181079d39681e7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Mar 2025 05:23:49 -0400 Subject: [PATCH 17/37] Support more float8 types. --- comfy/model_management.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2a9b022b..f1ecfc20 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -46,6 +46,32 @@ cpu_state = CPUState.GPU total_vram = 0 +def get_supported_float8_types(): + float8_types = [] + try: + float8_types.append(torch.float8_e4m3fn) + except: + pass + try: + float8_types.append(torch.float8_e4m3fnuz) + except: + pass + try: + float8_types.append(torch.float8_e5m2) + except: + pass + try: + float8_types.append(torch.float8_e5m2fnuz) + except: + pass + try: + float8_types.append(torch.float8_e8m0fnu) + except: + pass + return float8_types + +FLOAT8_TYPES = get_supported_float8_types() + xpu_available = False torch_version = "" try: @@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor return torch.float8_e5m2 fp8_dtype = None - try: - if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - fp8_dtype = weight_dtype - except: - pass + if weight_dtype in FLOAT8_TYPES: + fp8_dtype = weight_dtype if fp8_dtype is not None: if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive From 84fdaf7b0ef4d030723bc3b350282dc6c92743f6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Mar 2025 05:08:49 -0400 Subject: [PATCH 18/37] Add CFGZeroStar node. Works on all models that use a negative prompt but is meant for rectified flow models. --- comfy_extras/nodes_cfg.py | 45 +++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 46 insertions(+) create mode 100644 comfy_extras/nodes_cfg.py diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py new file mode 100644 index 00000000..1fb68664 --- /dev/null +++ b/comfy_extras/nodes_cfg.py @@ -0,0 +1,45 @@ +import torch + +# https://github.com/WeichenFan/CFG-Zero-star +def optimized_scale(positive, negative): + positive_flat = positive.reshape(positive.shape[0], -1) + negative_flat = negative.reshape(negative.shape[0], -1) + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) + +class CFGZeroStar: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "advanced/guidance" + + def patch(self, model): + m = model.clone() + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + m.set_model_sampler_post_cfg_function(cfg_zero_star) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "CFGZeroStar": CFGZeroStar +} diff --git a/nodes.py b/nodes.py index 27ef743b..272c2a25 100644 --- a/nodes.py +++ b/nodes.py @@ -2267,6 +2267,7 @@ def init_builtin_extra_nodes(): "nodes_lotus.py", "nodes_hunyuan3d.py", "nodes_primitive.py", + "nodes_cfg.py", ] import_failed = [] From 3661c833bcc41b788a7c9f0e7bc48524f8ee5f82 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Mar 2025 19:54:54 -0400 Subject: [PATCH 19/37] Support the WAN 2.1 fun control models. Use the new WanFunControlToVideo node. --- comfy/model_base.py | 17 ++++++++----- comfy/supported_models.py | 14 ++++++++++- comfy_extras/nodes_wan.py | 51 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 7 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index eec70d5d..315b5d1e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -992,7 +992,8 @@ class WAN21(BaseModel): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) - if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]: + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) @@ -1000,12 +1001,16 @@ class WAN21(BaseModel): if image is None: image = torch.zeros_like(noise) + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], 16): + image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + image = utils.resize_to_batch_size(image, noise.shape[0]) - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - image = self.process_latent_in(image) - image = utils.resize_to_batch_size(image, noise.shape[0]) - - if not self.image_to_video: + if not self.image_to_video or extra_channels == image.shape[1]: return image mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fad00d35..2a6a6156 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", "model_type": "i2v", + "in_dim": 36, } def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, image_to_video=True, device=device) return out +class WAN21_FunControl2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1013,6 +1025,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index dc30eb54..428874bc 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -3,6 +3,7 @@ import node_helpers import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class WanImageToVideo: @@ -49,6 +50,56 @@ class WanImageToVideo: return (positive, negative, out_latent) +class WanFunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, + "WanFunControlToVideo": WanFunControlToVideo, } From 0a1f8869c9998bbfcfeb2e97aa96a6d3e0a2b5df Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 27 Mar 2025 11:13:27 -0400 Subject: [PATCH 20/37] Add WanFunInpaintToVideo node for the Wan fun inpaint models. --- comfy/model_base.py | 7 +++-- comfy_extras/nodes_wan.py | 54 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 315b5d1e..8f588e2b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1017,11 +1017,14 @@ class WAN21(BaseModel): if mask is None: mask = torch.zeros_like(noise)[:, :4] else: - mask = 1.0 - torch.mean(mask, dim=1, keepdim=True) + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") if mask.shape[-3] < noise.shape[-3]: mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) - mask = mask.repeat(1, 4, 1, 1, 1) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) return torch.cat((mask, image), dim=1) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 428874bc..2d0f31ac 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -99,7 +99,61 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class WanFunInpaintToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, + "WanFunInpaintToVideo": WanFunInpaintToVideo, } From a40fcfc2d5392a5014cd87588035ebce194cb015 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 28 Mar 2025 02:27:01 -0400 Subject: [PATCH 21/37] Update frontend to 1.14.6 (#7416) Cherry-pick the fix: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3252 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c78d3c22..806fbc75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.14.5 +comfyui-frontend-package==1.14.6 torch torchsde torchvision From 2d17d8910c7d34383feaf1aaac8d08571fe42077 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Mar 2025 08:40:25 -0400 Subject: [PATCH 22/37] Don't error if wan concat image has extra channels. --- comfy/model_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8f588e2b..f55cbe18 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1013,6 +1013,9 @@ class WAN21(BaseModel): if not self.image_to_video or extra_channels == image.shape[1]: return image + if image.shape[1] > (extra_channels - 4): + image = image[:, :(extra_channels - 4)] + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :4] From 832fc02330c1843b9817b8ee90b061d2298a5911 Mon Sep 17 00:00:00 2001 From: Michael Kupchick Date: Sun, 30 Mar 2025 03:03:02 +0300 Subject: [PATCH 23/37] ltxv: fix preprocessing exception when compression is 0. (#7431) --- comfy_extras/nodes_lt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index fdc6c7c1..52588920 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -446,10 +446,9 @@ class LTXVPreprocess: CATEGORY = "image" def preprocess(self, image, img_compression): - if img_compression > 0: - output_images = [] - for i in range(image.shape[0]): - output_images.append(preprocess(image[i], img_compression)) + output_images = [] + for i in range(image.shape[0]): + output_images.append(preprocess(image[i], img_compression)) return (torch.stack(output_images),) From a3100c8452862e914996648e0fbc56098ab26b60 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Mar 2025 20:11:43 -0400 Subject: [PATCH 24/37] Remove useless code. --- comfy/model_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index f55cbe18..6bc627ae 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1000,7 +1000,6 @@ class WAN21(BaseModel): device = kwargs["device"] if image is None: - image = torch.zeros_like(noise) shape_image = list(noise.shape) shape_image[1] = extra_channels image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) From 0b4584c7413f1c3f6a34875a790c0381b3510447 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Mar 2025 21:47:05 -0400 Subject: [PATCH 25/37] Fix latent composite node not working when source has alpha. --- comfy_extras/nodes_mask.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 63fd13b9..2dd826b2 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -87,6 +87,8 @@ class ImageCompositeMasked: CATEGORY = "image" def composite(self, destination, source, x, y, resize_source, mask = None): + if destination.shape[-1] < source.shape[-1]: + source = source[...,:destination.shape[-1]] destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) From 548457bac47bb6c0ce233a9f5abb3467582d710d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 31 Mar 2025 20:59:12 -0400 Subject: [PATCH 26/37] Fix alpha channel mismatch on destination in ImageCompositeMasked --- comfy_extras/nodes_mask.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 2dd826b2..e1f0c822 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -89,6 +89,9 @@ class ImageCompositeMasked: def composite(self, destination, source, x, y, resize_source, mask = None): if destination.shape[-1] < source.shape[-1]: source = source[...,:destination.shape[-1]] + elif destination.shape[-1] > source.shape[-1]: + destination = torch.nn.functional.pad(destination, (0, 1)) + destination[..., -1] = source[..., -1] destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) From 301e26b131e99577aa64a366ca93c2bf85f34b96 Mon Sep 17 00:00:00 2001 From: BVH <82035780+bvhari@users.noreply.github.com> Date: Tue, 1 Apr 2025 23:18:53 +0530 Subject: [PATCH 27/37] Add option to store TE in bf16 (#7461) --- comfy/cli_args.py | 1 + comfy/model_management.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 91c1fe70..62079e6a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -79,6 +79,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") +fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") diff --git a/comfy/model_management.py b/comfy/model_management.py index f1ecfc20..84a260fc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -823,6 +823,8 @@ def text_encoder_dtype(device=None): return torch.float8_e5m2 elif args.fp16_text_enc: return torch.float16 + elif args.bf16_text_enc: + return torch.bfloat16 elif args.fp32_text_enc: return torch.float32 From 2b71aab29903c3d26d71f9ca2a034442a419ab0a Mon Sep 17 00:00:00 2001 From: Laurent Erignoux Date: Wed, 2 Apr 2025 01:53:52 +0800 Subject: [PATCH 28/37] User missing (#7439) * Ensuring a 401 error is returned when user data is not found in multi-user context. * Returning a 401 error when provided comfy-user does not exists on server side. --- app/app_settings.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/app/app_settings.py b/app/app_settings.py index a545df92..c7ac73bf 100644 --- a/app/app_settings.py +++ b/app/app_settings.py @@ -9,8 +9,14 @@ class AppSettings(): self.user_manager = user_manager def get_settings(self, request): - file = self.user_manager.get_request_user_filepath( - request, "comfy.settings.json") + try: + file = self.user_manager.get_request_user_filepath( + request, + "comfy.settings.json" + ) + except KeyError as e: + logging.error("User settings not found.") + raise web.HTTPUnauthorized() from e if os.path.isfile(file): try: with open(file) as f: From ab5413351eee61f3d7f10c74e75286df0058bb18 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 1 Apr 2025 14:09:31 -0400 Subject: [PATCH 29/37] Fix comment. This function does not support quads. --- comfy_extras/nodes_hunyuan3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 1ca7c2fe..5adc6b65 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -244,7 +244,7 @@ def save_glb(vertices, faces, filepath, metadata=None): Parameters: vertices: torch.Tensor of shape (N, 3) - The vertex coordinates - faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces) + faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) filepath: str - Output filepath (should end with .glb) """ From 2222cf67fdb2a3b805c622f7e309a6db2bb04d19 Mon Sep 17 00:00:00 2001 From: BiologicalExplosion <49753622+BiologicalExplosion@users.noreply.github.com> Date: Thu, 3 Apr 2025 07:24:04 +0800 Subject: [PATCH 30/37] MLU memory optimization (#7470) Co-authored-by: huzhan --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 84a260fc..19e6c8df 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1237,6 +1237,8 @@ def soft_empty_cache(force=False): torch.xpu.empty_cache() elif is_ascend_npu(): torch.npu.empty_cache() + elif is_mlu(): + torch.mlu.empty_cache() elif torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() From 3d2e3a6f29670809aa97b41505fa4e93ce11b98d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 2 Apr 2025 19:32:34 -0400 Subject: [PATCH 31/37] Fix alpha image issue in more nodes. --- comfy_extras/nodes_mask.py | 7 ++----- comfy_extras/nodes_post_processing.py | 3 ++- node_helpers.py | 8 ++++++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index e1f0c822..13d2b4ba 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -2,6 +2,7 @@ import numpy as np import scipy.ndimage import torch import comfy.utils +import node_helpers from nodes import MAX_RESOLUTION @@ -87,11 +88,7 @@ class ImageCompositeMasked: CATEGORY = "image" def composite(self, destination, source, x, y, resize_source, mask = None): - if destination.shape[-1] < source.shape[-1]: - source = source[...,:destination.shape[-1]] - elif destination.shape[-1] > source.shape[-1]: - destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = source[..., -1] + destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 68f6ef51..5b954201 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,7 +6,7 @@ import math import comfy.utils import comfy.model_management - +import node_helpers class Blend: def __init__(self): @@ -34,6 +34,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) diff --git a/node_helpers.py b/node_helpers.py index 48da3b09..4f805387 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -44,3 +44,11 @@ def string_to_torch_dtype(string): return torch.float16 if string == "bf16": return torch.bfloat16 + +def image_alpha_fix(destination, source): + if destination.shape[-1] < source.shape[-1]: + source = source[...,:destination.shape[-1]] + elif destination.shape[-1] > source.shape[-1]: + destination = torch.nn.functional.pad(destination, (0, 1)) + destination[..., -1] = source[..., -1] + return destination, source From 721253cb0527e0476f12bd20835b4fff5961508e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 3 Apr 2025 20:57:59 -0400 Subject: [PATCH 32/37] Fix problem. --- node_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/node_helpers.py b/node_helpers.py index 4f805387..c3e1a14c 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -50,5 +50,5 @@ def image_alpha_fix(destination, source): source = source[...,:destination.shape[-1]] elif destination.shape[-1] > source.shape[-1]: destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = source[..., -1] + destination[..., -1] = 1.0 return destination, source From 3a100b9a550b9700d08eecb006b5accd65863925 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 4 Apr 2025 21:24:56 -0400 Subject: [PATCH 33/37] Disable partial offloading of audio VAE. --- comfy/sd.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d096f496..4d3aef3e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -265,6 +265,7 @@ class VAE: self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] + self.disable_offload = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -337,6 +338,7 @@ class VAE: self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.disable_offload = True elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae if "blocks.2.blocks.3.stack.5.weight" in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."}) @@ -515,7 +517,7 @@ class VAE: pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -544,7 +546,7 @@ class VAE: def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) dims = samples.ndim - 2 args = {} if tile_x is not None: @@ -578,7 +580,7 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) @@ -612,7 +614,7 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) args = {} if tile_x is not None: From 89e4ea01754fc043913ac164f5b7880ec58ebab9 Mon Sep 17 00:00:00 2001 From: Raphael Walker Date: Sat, 5 Apr 2025 03:27:54 +0200 Subject: [PATCH 34/37] Add activations_shape info in UNet models (#7482) * Add activations_shape info in UNet models * activations_shape should be a list --- comfy/ldm/modules/attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ede50646..45f9e311 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -847,6 +847,7 @@ class SpatialTransformer(nn.Module): if not isinstance(context, list): context = [context] * len(self.transformer_blocks) b, c, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x x = self.norm(x) if not self.use_linear: @@ -962,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer): transformer_options={} ) -> torch.Tensor: _, _, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x spatial_context = None if exists(context): From 3bfe4e527665d71a3cc88fe06e2733209602ae3a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 5 Apr 2025 06:14:10 -0400 Subject: [PATCH 35/37] Support 512 siglip model. --- comfy/clip_vision.py | 8 ++++++-- comfy/clip_vision_siglip_512.json | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 comfy/clip_vision_siglip_512.json diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 87d32a66..11bc5778 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -110,9 +110,13 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: + embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") + elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") else: diff --git a/comfy/clip_vision_siglip_512.json b/comfy/clip_vision_siglip_512.json new file mode 100644 index 00000000..7fb93ce1 --- /dev/null +++ b/comfy/clip_vision_siglip_512.json @@ -0,0 +1,13 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 512, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} From 49b732afd54e1871d59fd0bca9e7a3a97e3532ea Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 6 Apr 2025 22:43:56 -0400 Subject: [PATCH 36/37] Show a proper error to the user when a vision model file is invalid. --- nodes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nodes.py b/nodes.py index 272c2a25..218f9325 100644 --- a/nodes.py +++ b/nodes.py @@ -1006,6 +1006,8 @@ class CLIPVisionLoader: def load_clip(self, clip_name): clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name) clip_vision = comfy.clip_vision.load(clip_path) + if clip_vision is None: + raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") return (clip_vision,) class CLIPVisionEncode: From 70d7242e57e853c489b608e88d7874e546474604 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 7 Apr 2025 05:01:47 -0400 Subject: [PATCH 37/37] Support the wan fun reward loras. --- comfy/lora_convert.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 05032c69..3e00b63d 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -1,4 +1,5 @@ import torch +import comfy.utils def convert_lora_bfl_control(sd): #BFL loras for Flux @@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux return sd_out +def convert_lora_wan_fun(sd): #Wan Fun loras + return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"}) + + def convert_lora(sd): if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: return convert_lora_bfl_control(sd) + if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd: + return convert_lora_wan_fun(sd) return sd