diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py
index 3e9c6497..84fd6d83 100644
--- a/comfy/ldm/cosmos/blocks.py
+++ b/comfy/ldm/cosmos/blocks.py
@@ -168,14 +168,18 @@ class Attention(nn.Module):
         k = self.to_k[1](k)
         v = self.to_v[1](v)
         if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
-            q = apply_rotary_pos_emb(q, rope_emb)
-            k = apply_rotary_pos_emb(k, rope_emb)
-        return q, k, v
+            # apply_rotary_pos_emb inlined
+            q_shape = q.shape
+            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
+            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)

-    def cal_attn(self, q, k, v, mask=None):
-        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
-        out = rearrange(out, " b n s c -> s b (n c)")
-        return self.to_out(out)
+            # apply_rotary_pos_emb inlined
+            k_shape = k.shape
+            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
+            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+        return q, k, v

     def forward(
         self,
@@ -191,7 +195,10 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-        return self.cal_attn(q, k, v, mask)
+        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+        del q, k, v
+        out = rearrange(out, " b n s c -> s b (n c)")
+        return self.to_out(out)


 class FeedForward(nn.Module):
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
         crossattn_mask: Optional[torch.Tensor] = None,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
-        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
-        if extra_per_block_pos_emb is not None:
-            x = x + extra_per_block_pos_emb
         for block in self.blocks:
             x = block(
                 x,
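Note: the two inlined blocks above perform the same rotation, once on q and once on k. For reference, a sketch of the shared transform they both implement (the helper name rotate_pairs is hypothetical; rope_emb is assumed to carry the two rotation coefficients in its trailing dimension, exactly as used in the diff):

    import torch

    def rotate_pairs(t: torch.Tensor, rope_emb: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
        t_shape = t.shape
        # view the channel dim as 2 x (D/2): element i is paired with element i + D/2
        t = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
        # per-pair 2D rotation, fused into one expression so intermediates die early
        t = rope_emb[..., 0] * t[..., 0] + rope_emb[..., 1] * t[..., 1]
        return t.movedim(-1, -2).reshape(*t_shape).to(out_dtype)

Together with the del q, k, v after attention, the point of the inlining is to let intermediate tensors be freed as soon as possible, lowering peak memory.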
diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
index 6149e53e..9a3ebed6 100644
--- a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
@@ -30,6 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging

+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
 from .patching import (
     Patcher,
     Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )

+        self.optimized_attention = vae_attention()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
         v, batch_size = time2batch(v)

         b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)
-        k = k.reshape(b, c, h * w)
-        w_ = torch.bmm(q, k)
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = F.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)
-        h_ = torch.bmm(v, w_)
-        h_ = h_.reshape(b, c, h, w)
+        h_ = self.optimized_attention(q, k, v)

         h_ = batch2time(h_, batch_size)
         h_ = self.proj_out(h_)
@@ -871,18 +864,16 @@ class EncoderFactorized(nn.Module):
         x = self.patcher3d(x)

         # downsampling
-        hs = [self.conv_in(x)]
+        h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](hs[-1])
+                h = self.down[i_level].block[i_block](h)
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
-                hs.append(h)
             if i_level != self.num_resolutions - 1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+                h = self.down[i_level].downsample(h)

         # middle
-        h = hs[-1]
         h = self.mid.block_1(h)
         h = self.mid.attn_1(h)
         h = self.mid.block_2(h)
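Note: the deleted reshape/bmm/softmax block was plain scaled dot-product attention over the h*w spatial tokens; CausalAttnBlock now delegates to vae_attention() (added in comfy/ldm/modules/diffusionmodules/model.py further down), which selects an optimized backend once at construction. A minimal equivalence sketch, assuming q, k, v of shape [b, c, h, w]:

    import torch.nn.functional as F

    def reference_attention(q, k, v):
        b, c, h, w = q.shape
        q, k, v = (t.reshape(b, c, h * w).permute(0, 2, 1) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)  # softmax(q @ k^T / sqrt(c)) @ v
        return out.permute(0, 2, 1).reshape(b, c, h, w)

Likewise, EncoderFactorized no longer accumulates every intermediate in an hs list; nothing consumed those entries except hs[-1], so keeping only the current h is enough and frees each feature map as soon as the next one is computed.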
diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py
index 793f0da8..87a53a1d 100644
--- a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py
@@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
         hh = hh.to(dtype=dtype)

         xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
+        del x

         # Height height transposed convolutions.
         xll = F.conv_transpose3d(
             xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlll
+
         xll += F.conv_transpose3d(
             xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xllh

         xlh = F.conv_transpose3d(
             xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhl
+
         xlh += F.conv_transpose3d(
             xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhh

         xhl = F.conv_transpose3d(
             xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhll
+
         xhl += F.conv_transpose3d(
             xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhlh

         xhh = F.conv_transpose3d(
             xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhl
+
         xhh += F.conv_transpose3d(
             xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhh

         # Handles width transposed convolutions.
         xl = F.conv_transpose3d(
             xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xll
+
         xl += F.conv_transpose3d(
             xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xlh
+
         xh = F.conv_transpose3d(
             xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhl
+
         xh += F.conv_transpose3d(
             xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhh

         # Handles time axis transposed convolutions.
         x = F.conv_transpose3d(
             xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
+        del xl
+
         x += F.conv_transpose3d(
             xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
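Note: the del statements do not change the math; each one drops the last reference to a wavelet sub-band as soon as it has been folded into its transposed-convolution accumulator, so the allocator can reuse that block for the next conv_transpose3d instead of holding all eight octants alive through the whole reconstruction. A quick way to verify the effect of this kind of change (unpatcher and latents are placeholder names; a CUDA device is assumed):

    import torch

    torch.cuda.reset_peak_memory_stats()
    out = unpatcher(latents)
    print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")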
diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py
index 05dd3846..06d0baef 100644
--- a/comfy/ldm/cosmos/model.py
+++ b/comfy/ldm/cosmos/model.py
@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )

-        self.build_pos_embed(device=device)
+        self.build_pos_embed(device=device, dtype=dtype)
         self.block_x_format = block_x_format
         self.use_adaln_lora = use_adaln_lora
         self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )

-    def build_pos_embed(self, device=None):
+    def build_pos_embed(self, device=None, dtype=None):
         if self.pos_emb_cls == "rope3d":
             cls_type = VideoRopePosition3DEmb
         else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
             kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
             kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
             kwargs["device"] = device
+            kwargs["dtype"] = dtype
             self.extra_pos_embedder = LearnablePosEmbAxis(
                 **kwargs,
             )
@@ -292,7 +293,7 @@ class GeneralDIT(nn.Module):
         x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

         if self.extra_per_block_abs_pos_emb:
-            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
+            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
         else:
             extra_pos_emb = None
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
                 inputs["original_shape"],
             )
             extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
+            del inputs
+
         if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
             assert (
                 x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
                 self.blocks["block0"].x_format == block.x_format
             ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"

+            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
             x = block(
                 x,
                 affline_emb_B_D,
@@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
                 crossattn_emb,
                 crossattn_mask,
                 rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                 adaln_lora_B_3D=adaln_lora_B_3D,
-                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             )

         x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py
index dda752cb..4d6a58db 100644
--- a/comfy/ldm/cosmos/position_embedding.py
+++ b/comfy/ldm/cosmos/position_embedding.py
@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)


 class VideoPositionEmb(nn.Module):
-    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
         """
         It delegates the embedding generation to generate_embeddings function.
         """
         B_T_H_W_C = x_B_T_H_W_C.shape
-        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
+        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)

         return embeddings

@@ -104,6 +104,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
         w_ntk_factor: Optional[float] = None,
         t_ntk_factor: Optional[float] = None,
         device=None,
+        dtype=None,
     ):
         """
         Generate embeddings for the given input size.
@@ -173,6 +174,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         len_w: int,
         len_t: int,
         device=None,
+        dtype=None,
         **kwargs,
     ):
         """
@@ -184,17 +186,16 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         self.interpolation = interpolation
         assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

-        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
-        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
-        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
+        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))

-
-    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
         B, T, H, W, _ = B_T_H_W_C
         if self.interpolation == "crop":
-            emb_h_H = self.pos_emb_h[:H].to(device=device)
-            emb_w_W = self.pos_emb_w[:W].to(device=device)
-            emb_t_T = self.pos_emb_t[:T].to(device=device)
+            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
+            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
+            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
             emb = (
                 repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                 + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
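Note: plumbing dtype from the model constructor through build_pos_embed into LearnablePosEmbAxis means the positional-embedding parameters are created directly in the model's compute dtype, rather than defaulting to float32 and being cast (with a temporary copy) on every forward. A minimal illustration of the difference (shapes arbitrary):

    import torch
    import torch.nn as nn

    p_default = nn.Parameter(torch.empty(64, 4096))                      # float32
    p_typed = nn.Parameter(torch.empty(64, 4096, dtype=torch.bfloat16))  # stored as used
    print(p_default.dtype, p_typed.dtype)  # torch.float32 torch.bfloat16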
diff --git a/comfy/ldm/cosmos/vae.py b/comfy/ldm/cosmos/vae.py
index 94fcc54c..d64f292d 100644
--- a/comfy/ldm/cosmos/vae.py
+++ b/comfy/ldm/cosmos/vae.py
@@ -18,6 +18,7 @@ import logging
 import torch
 from torch import nn
 from enum import Enum
+import math

 from .cosmos_tokenizer.layers3d import (
     EncoderFactorized,
@@ -89,8 +90,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
         self.distribution = IdentityDistribution()  # ContinuousFormulation[formulation_name].value()

         num_parameters = sum(param.numel() for param in self.parameters())
-        logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
-        logging.info(
+        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
+        logging.debug(
             f"z_channels={z_channels}, latent_channels={self.latent_channels}."
         )

@@ -105,17 +106,23 @@ class CausalContinuousVideoTokenizer(nn.Module):
         z, posteriors = self.distribution(moments)
         latent_ch = z.shape[1]
         latent_t = z.shape[2]
-        dtype = z.dtype
-        mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
-        std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
+        in_dtype = z.dtype
+        mean = self.latent_mean.view(latent_ch, -1)
+        std = self.latent_std.view(latent_ch, -1)
+
+        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
         return ((z - mean) / std) * self.sigma_data

     def decode(self, z):
         in_dtype = z.dtype
         latent_ch = z.shape[1]
         latent_t = z.shape[2]
-        mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
-        std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+        mean = self.latent_mean.view(latent_ch, -1)
+        std = self.latent_std.view(latent_ch, -1)
+
+        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)

         z = z / self.sigma_data
         z = z * std + mean
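Note: previously encode/decode sliced the stored per-frame latent statistics with [:, :latent_t], which yields a tensor shorter than z's time dimension whenever latent_t exceeds the stored length, and the subsequent broadcast then fails. The repeat(...) tiles the statistics along time before slicing. Sketch of the tiling behaviour:

    import math
    import torch

    stats = torch.arange(4.0).view(1, -1)  # pretend stats were stored for 4 latent frames
    latent_t = 10
    tiled = stats.repeat(1, math.ceil(latent_t / stats.shape[-1]))[:, :latent_t]
    print(tiled)  # tensor([[0., 1., 2., 3., 0., 1., 2., 3., 0., 1.]])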
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 8e055151..59a62e0d 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index b6549585..b5960ffd 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -5,8 +5,15 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management

+
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
-    q, k = apply_rope(q, k, pe)
+    q_shape = q.shape
+    k_shape = k.shape
+
+    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
+    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
+    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index ed1e8821..303147a9 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
     return out


+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                         stride=1,
                                         padding=0)

-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()

     def forward(self, x):
         h_ = x
diff --git a/comfy/sd.py b/comfy/sd.py
index 6ba6af47..d7e89f72 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -388,8 +388,8 @@ class VAE:
             ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
             self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
             #TODO: these values are a bit off because this is not a standard VAE
-            self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
-            self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index ff3f1432..ff0bea41 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.HunyuanVideo

-    memory_usage_factor = 2.0 #TODO
+    memory_usage_factor = 1.8 #TODO

     supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.Cosmos1CV8x8x8

-    memory_usage_factor = 2.4 #TODO
+    memory_usage_factor = 1.6 #TODO

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO

diff --git a/comfyui_version.py b/comfyui_version.py
index 7cccc753..411243f6 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.10"
+__version__ = "0.3.12"
diff --git a/pyproject.toml b/pyproject.toml
index b747d6ef..0198d1b0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.10"
+version = "0.3.12"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
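Note on the comfy/sd.py estimates above: decode drops its per-element factor from 220 to 50, and encode now scales with the frame count rounded up to a multiple of 8 (the temporal compression; for the 8*n+1 frame lengths Cosmos expects this is exactly the next multiple of 8) instead of 500 * max(frames, 2). A quick check of the rounding term:

    for t in (1, 9, 17, 33, 121):
        print(t, round((t + 7) / 8) * 8)  # -> 8, 16, 24, 40, 128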