mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 16:43:36 +00:00
Lower cosmos diffusion model memory usage.
This commit is contained in:
parent
4758fb64b9
commit
25683b5b02
@ -168,14 +168,18 @@ class Attention(nn.Module):
|
||||
k = self.to_k[1](k)
|
||||
v = self.to_v[1](v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
q = apply_rotary_pos_emb(q, rope_emb)
|
||||
k = apply_rotary_pos_emb(k, rope_emb)
|
||||
return q, k, v
|
||||
# apply_rotary_pos_emb inlined
|
||||
q_shape = q.shape
|
||||
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
|
||||
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
|
||||
|
||||
def cal_attn(self, q, k, v, mask=None):
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
# apply_rotary_pos_emb inlined
|
||||
k_shape = k.shape
|
||||
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
|
||||
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
|
||||
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -191,7 +195,10 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
return self.cal_attn(q, k, v, mask)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x = x + extra_per_block_pos_emb
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
x,
|
||||
|
@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.build_pos_embed(device=device)
|
||||
self.build_pos_embed(device=device, dtype=dtype)
|
||||
self.block_x_format = block_x_format
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def build_pos_embed(self, device=None):
|
||||
def build_pos_embed(self, device=None, dtype=None):
|
||||
if self.pos_emb_cls == "rope3d":
|
||||
cls_type = VideoRopePosition3DEmb
|
||||
else:
|
||||
@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
|
||||
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||
kwargs["device"] = device
|
||||
kwargs["dtype"] = dtype
|
||||
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||
**kwargs,
|
||||
)
|
||||
@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
|
||||
inputs["original_shape"],
|
||||
)
|
||||
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
|
||||
del inputs
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
assert (
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
|
||||
x = block(
|
||||
x,
|
||||
affline_emb_B_D,
|
||||
@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
|
||||
len_w: int,
|
||||
len_t: int,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
|
||||
self.interpolation = interpolation
|
||||
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
|
||||
|
||||
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
|
||||
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
|
||||
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
|
||||
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
|
||||
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
||||
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
||||
|
||||
|
||||
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
|
||||
|
Loading…
Reference in New Issue
Block a user