Merge branch 'master' into multigpu_support

Jedrzej Kosinski 2025-01-16 18:25:05 -06:00
commit 31f5458938
13 changed files with 105 additions and 67 deletions

View File

@@ -168,14 +168,18 @@ class Attention(nn.Module):
         k = self.to_k[1](k)
         v = self.to_v[1](v)
         if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
-            q = apply_rotary_pos_emb(q, rope_emb)
-            k = apply_rotary_pos_emb(k, rope_emb)
-        return q, k, v
-
-    def cal_attn(self, q, k, v, mask=None):
-        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
-        out = rearrange(out, " b n s c -> s b (n c)")
-        return self.to_out(out)
+            # apply_rotary_pos_emb inlined
+            q_shape = q.shape
+            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
+            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
+
+            # apply_rotary_pos_emb inlined
+            k_shape = k.shape
+            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
+            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+        return q, k, v
 
     def forward(
         self,
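Note on the hunk above: the removed apply_rotary_pos_emb call is replaced by its inlined body, which applies rotary position embeddings (RoPE) to q and k as a rotation over channel pairs. A generic, self-contained sketch of that operation follows; the pair layout and the toy cos/sin tensors are simplifications for illustration and do not mirror the exact rope_emb layout used in this file.

    import torch

    def rotary_reference(x, cos, sin):
        # View the last dim of x as (real, imag) pairs and rotate each pair
        # by the per-position angle encoded in cos/sin.
        x_pairs = x.reshape(*x.shape[:-1], -1, 2)
        real, imag = x_pairs[..., 0], x_pairs[..., 1]
        out = torch.stack((real * cos - imag * sin, real * sin + imag * cos), dim=-1)
        return out.reshape(*x.shape)

    seq, dim = 4, 8
    x = torch.randn(seq, dim)
    theta = torch.arange(seq, dtype=torch.float32).unsqueeze(-1).repeat(1, dim // 2)  # toy angles
    rotated = rotary_reference(x, torch.cos(theta), torch.sin(theta))
    print(rotated.shape)  # torch.Size([4, 8])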
@@ -191,7 +195,10 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-        return self.cal_attn(q, k, v, mask)
+        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+        del q, k, v
+        out = rearrange(out, " b n s c -> s b (n c)")
+        return self.to_out(out)
 
 
 class FeedForward(nn.Module):
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
         crossattn_mask: Optional[torch.Tensor] = None,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
-        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if extra_per_block_pos_emb is not None:
-            x = x + extra_per_block_pos_emb
         for block in self.blocks:
             x = block(
                 x,

View File

@@ -30,6 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging
 
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
 from .patching import (
     Patcher,
     Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
 
+        self.optimized_attention = vae_attention()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
v, batch_size = time2batch(v) v, batch_size = time2batch(v)
b, c, h, w = q.shape b, c, h, w = q.shape
q = q.reshape(b, c, h * w) h_ = self.optimized_attention(q, k, v)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = batch2time(h_, batch_size) h_ = batch2time(h_, batch_size)
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
@@ -871,18 +864,16 @@ class EncoderFactorized(nn.Module):
         x = self.patcher3d(x)
 
         # downsampling
-        hs = [self.conv_in(x)]
+        h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](hs[-1])
+                h = self.down[i_level].block[i_block](h)
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
-                hs.append(h)
             if i_level != self.num_resolutions - 1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+                h = self.down[i_level].downsample(h)
 
         # middle
-        h = hs[-1]
         h = self.mid.block_1(h)
         h = self.mid.attn_1(h)
         h = self.mid.block_2(h)
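Note on the encoder hunk above: the removed hs list was only ever read through hs[-1], so the rewrite keeps a single running tensor instead of accumulating every intermediate activation. An illustrative sketch (not code from this commit) of why that lowers peak memory:

    import torch
    import torch.nn as nn

    def list_style(x, layers):
        hs = [x]
        for layer in layers:
            hs.append(layer(hs[-1]))  # every intermediate stays referenced until return
        return hs[-1]

    def running_style(x, layers):
        h = x
        for layer in layers:
            h = layer(h)  # the previous activation loses its last reference and can be freed
        return h

    layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
    x = torch.randn(2, 8)
    assert torch.allclose(list_style(x, layers), running_style(x, layers))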

View File

@@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
         hh = hh.to(dtype=dtype)
 
         xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
+        del x
 
         # Height height transposed convolutions.
         xll = F.conv_transpose3d(
             xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlll
+
         xll += F.conv_transpose3d(
             xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xllh
+
         xlh = F.conv_transpose3d(
             xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhl
+
         xlh += F.conv_transpose3d(
             xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhh
+
         xhl = F.conv_transpose3d(
             xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhll
+
         xhl += F.conv_transpose3d(
             xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhlh
+
         xhh = F.conv_transpose3d(
             xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhl
+
         xhh += F.conv_transpose3d(
             xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhh
 
         # Handles width transposed convolutions.
         xl = F.conv_transpose3d(
             xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xll
         xl += F.conv_transpose3d(
             xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xlh
         xh = F.conv_transpose3d(
             xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhl
         xh += F.conv_transpose3d(
             xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhh
 
         # Handles time axis transposed convolutions.
         x = F.conv_transpose3d(
             xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
+        del xl
         x += F.conv_transpose3d(
             xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )

View File

@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-        self.build_pos_embed(device=device)
+        self.build_pos_embed(device=device, dtype=dtype)
         self.block_x_format = block_x_format
         self.use_adaln_lora = use_adaln_lora
         self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-    def build_pos_embed(self, device=None):
+    def build_pos_embed(self, device=None, dtype=None):
         if self.pos_emb_cls == "rope3d":
             cls_type = VideoRopePosition3DEmb
         else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
             kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
             kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
             kwargs["device"] = device
+            kwargs["dtype"] = dtype
             self.extra_pos_embedder = LearnablePosEmbAxis(
                 **kwargs,
             )
@@ -292,7 +293,7 @@ class GeneralDIT(nn.Module):
         x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
 
         if self.extra_per_block_abs_pos_emb:
-            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
+            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
         else:
             extra_pos_emb = None
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
             inputs["original_shape"],
         )
         extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
+        del inputs
+
         if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
             assert (
                 x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
                 self.blocks["block0"].x_format == block.x_format
             ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
 
+            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
             x = block(
                 x,
                 affline_emb_B_D,
@@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
                 crossattn_mask,
                 rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                 adaln_lora_B_3D=adaln_lora_B_3D,
-                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             )
 
         x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
 
 
 class VideoPositionEmb(nn.Module):
-    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
         """
         It delegates the embedding generation to generate_embeddings function.
         """
         B_T_H_W_C = x_B_T_H_W_C.shape
-        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
+        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
 
         return embeddings
@@ -104,6 +104,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
         w_ntk_factor: Optional[float] = None,
         t_ntk_factor: Optional[float] = None,
         device=None,
+        dtype=None,
     ):
         """
         Generate embeddings for the given input size.
@@ -173,6 +174,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         len_w: int,
         len_t: int,
         device=None,
+        dtype=None,
         **kwargs,
     ):
         """
@@ -184,17 +186,16 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         self.interpolation = interpolation
         assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
 
-        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
-        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
-        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
-
-    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
+
+    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
         B, T, H, W, _ = B_T_H_W_C
         if self.interpolation == "crop":
-            emb_h_H = self.pos_emb_h[:H].to(device=device)
-            emb_w_W = self.pos_emb_w[:W].to(device=device)
-            emb_t_T = self.pos_emb_t[:T].to(device=device)
+            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
+            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
+            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
             emb = (
                 repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                 + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)

View File

@@ -18,6 +18,7 @@ import logging
 import torch
 from torch import nn
 from enum import Enum
+import math
 
 from .cosmos_tokenizer.layers3d import (
     EncoderFactorized,
@@ -89,8 +90,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
         self.distribution = IdentityDistribution()  # ContinuousFormulation[formulation_name].value()
 
         num_parameters = sum(param.numel() for param in self.parameters())
-        logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
-        logging.info(
+        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
+        logging.debug(
             f"z_channels={z_channels}, latent_channels={self.latent_channels}."
         )
@@ -105,17 +106,23 @@ class CausalContinuousVideoTokenizer(nn.Module):
         z, posteriors = self.distribution(moments)
         latent_ch = z.shape[1]
         latent_t = z.shape[2]
-        dtype = z.dtype
-        mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
-        std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
+        in_dtype = z.dtype
+        mean = self.latent_mean.view(latent_ch, -1)
+        std = self.latent_std.view(latent_ch, -1)
+        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
         return ((z - mean) / std) * self.sigma_data
 
     def decode(self, z):
         in_dtype = z.dtype
         latent_ch = z.shape[1]
         latent_t = z.shape[2]
-        mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
-        std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+        mean = self.latent_mean.view(latent_ch, -1)
+        std = self.latent_std.view(latent_ch, -1)
+        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+
         z = z / self.sigma_data
         z = z * std + mean
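Note on the normalize/denormalize hunks above: latent_mean and latent_std store per-channel statistics for a fixed number of temporal entries, and the new repeat(...) followed by slicing to latent_t tiles those statistics so they cover latents with more time steps than were stored. A small standalone sketch of the tiling arithmetic with made-up sizes:

    import math
    import torch

    latent_ch, stored_t, latent_t = 16, 8, 13       # hypothetical sizes
    latent_mean = torch.randn(latent_ch * stored_t)

    mean = latent_mean.view(latent_ch, -1)           # [16, 8]
    reps = math.ceil(latent_t / mean.shape[-1])      # ceil(13 / 8) = 2
    mean = mean.repeat(1, reps)[:, :latent_t]        # tile to [16, 16], then slice to [16, 13]
    mean = mean.reshape([1, latent_ch, -1, 1, 1])    # [1, 16, 13, 1, 1], broadcastable over [B, C, T, H, W]
    print(mean.shape)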

View File

@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)

View File

@@ -5,8 +5,15 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
 
+
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
-    q, k = apply_rope(q, k, pe)
+    q_shape = q.shape
+    k_shape = k.shape
+    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
+    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
+    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
 
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)

View File

@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
     return out
 
 
+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                 stride=1,
                                 padding=0)
 
-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()
 
     def forward(self, x):
         h_ = x

View File

@@ -388,8 +388,8 @@ class VAE:
             ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
             self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
             #TODO: these values are a bit off because this is not a standard VAE
-            self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
-            self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
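Note on the revised memory estimates above: a rough worked example of the new encode formula, assuming shape is the pixel-space [B, C, T, H, W] input and a 2-byte dtype such as bfloat16 (the concrete sizes are hypothetical):

    shape = (1, 3, 33, 704, 1280)                    # hypothetical input video
    dtype_size = 2                                   # bytes per element for bfloat16
    t_padded = round((shape[2] + 7) / 8) * 8         # 33 frames -> 40
    est_bytes = 50 * t_padded * shape[3] * shape[4] * dtype_size
    print(f"{est_bytes / 1024**3:.1f} GiB")          # ~3.4 GiB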

View File

@@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
 
     unet_extra_config = {}
     latent_format = latent_formats.HunyuanVideo
-    memory_usage_factor = 2.0 #TODO
+    memory_usage_factor = 1.8 #TODO
 
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
 
     unet_extra_config = {}
     latent_format = latent_formats.Cosmos1CV8x8x8
-    memory_usage_factor = 2.4 #TODO
+    memory_usage_factor = 1.6 #TODO
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO

View File

@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.10"
+__version__ = "0.3.12"

View File

@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.10"
+version = "0.3.12"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"