From 25683b5b0269590ba24f96753cf55cc6ad093cd0 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 15 Jan 2025 23:46:42 -0500
Subject: [PATCH] Lower cosmos diffusion model memory usage.

---
 comfy/ldm/cosmos/blocks.py             | 26 +++++++++++++++-----------
 comfy/ldm/cosmos/model.py              | 10 +++++++---
 comfy/ldm/cosmos/position_embedding.py |  7 ++++---
 3 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py
index 3e9c6497..84fd6d83 100644
--- a/comfy/ldm/cosmos/blocks.py
+++ b/comfy/ldm/cosmos/blocks.py
@@ -168,14 +168,18 @@ class Attention(nn.Module):
         k = self.to_k[1](k)
         v = self.to_v[1](v)
         if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
-            q = apply_rotary_pos_emb(q, rope_emb)
-            k = apply_rotary_pos_emb(k, rope_emb)
-        return q, k, v
+            # apply_rotary_pos_emb inlined
+            q_shape = q.shape
+            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
+            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
 
-    def cal_attn(self, q, k, v, mask=None):
-        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
-        out = rearrange(out, " b n s c -> s b (n c)")
-        return self.to_out(out)
+            # apply_rotary_pos_emb inlined
+            k_shape = k.shape
+            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
+            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+        return q, k, v
 
     def forward(
         self,
@@ -191,7 +195,10 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-        return self.cal_attn(q, k, v, mask)
+        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+        del q, k, v
+        out = rearrange(out, " b n s c -> s b (n c)")
+        return self.to_out(out)
 
 
 class FeedForward(nn.Module):
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
         crossattn_mask: Optional[torch.Tensor] = None,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
-        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if extra_per_block_pos_emb is not None:
-            x = x + extra_per_block_pos_emb
         for block in self.blocks:
             x = block(
                 x,
diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py
index 05dd3846..1205838b 100644
--- a/comfy/ldm/cosmos/model.py
+++ b/comfy/ldm/cosmos/model.py
@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-        self.build_pos_embed(device=device)
+        self.build_pos_embed(device=device, dtype=dtype)
         self.block_x_format = block_x_format
         self.use_adaln_lora = use_adaln_lora
         self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-    def build_pos_embed(self, device=None):
+    def build_pos_embed(self, device=None, dtype=None):
         if self.pos_emb_cls == "rope3d":
             cls_type = VideoRopePosition3DEmb
         else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
             kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
             kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
             kwargs["device"] = device
+            kwargs["dtype"] = dtype
             self.extra_pos_embedder = LearnablePosEmbAxis(
                 **kwargs,
             )
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
             inputs["original_shape"],
         )
         extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
+        del inputs
+
         if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
             assert (
                 x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@
                 self.blocks["block0"].x_format == block.x_format
             ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
 
+            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
             x = block(
                 x,
                 affline_emb_B_D,
@@ -493,7 +498,6 @@
                 crossattn_mask,
                 rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                 adaln_lora_B_3D=adaln_lora_B_3D,
-                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             )
 
         x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py
index dda752cb..cf45ab0e 100644
--- a/comfy/ldm/cosmos/position_embedding.py
+++ b/comfy/ldm/cosmos/position_embedding.py
@@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         len_w: int,
         len_t: int,
         device=None,
+        dtype=None,
         **kwargs,
     ):
         """
@@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         self.interpolation = interpolation
         assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
 
-        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
-        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
-        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
+        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
 
     def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor: