From 008761166fdf90db95f7f757f6f995be8bded508 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 15 Jan 2025 21:48:46 -0500
Subject: [PATCH] Optimize first attention block in cosmos VAE.

---
 comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py | 17 +++++----------
 comfy/ldm/modules/diffusionmodules/model.py   | 21 +++++++++++--------
 2 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
index 6149e53e..7d864a75 100644
--- a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
@@ -30,6 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging
 
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
 from .patching import (
     Patcher,
     Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
 
+        self.optimized_attention = vae_attention()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
         v, batch_size = time2batch(v)
 
         b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)
-        k = k.reshape(b, c, h * w)
-        w_ = torch.bmm(q, k)
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = F.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)
-        h_ = torch.bmm(v, w_)
-        h_ = h_.reshape(b, c, h, w)
+        h_ = self.optimized_attention(q, k, v)
 
         h_ = batch2time(h_, batch_size)
         h_ = self.proj_out(h_)
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index ed1e8821..303147a9 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
     return out
 
 
+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                         stride=1,
                                         padding=0)
 
-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()
 
     def forward(self, x):
         h_ = x
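
For context, a minimal standalone sketch (not part of the patch; the helper names manual_attention and fused_attention are illustrative): the twelve lines removed from CausalAttnBlock.forward compute plain scaled dot-product attention over the h*w spatial positions, so routing them through a fused kernel such as torch.nn.functional.scaled_dot_product_attention (the kind of backend the pytorch_attention path builds on) changes only speed and memory use, not the result.

    import torch
    import torch.nn.functional as F


    def manual_attention(q, k, v):
        # The math removed by this patch: explicit QK^T, scale, softmax, then V.
        b, c, h, w = q.shape
        q_ = q.reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
        k_ = k.reshape(b, c, h * w)                   # (b, c, hw)
        w_ = torch.bmm(q_, k_) * (int(c) ** (-0.5))   # (b, hw, hw)
        w_ = F.softmax(w_, dim=2)
        v_ = v.reshape(b, c, h * w)                   # (b, c, hw)
        h_ = torch.bmm(v_, w_.permute(0, 2, 1))       # (b, c, hw)
        return h_.reshape(b, c, h, w)


    def fused_attention(q, k, v):
        # Same math through the fused SDPA kernel, on (b, 1, hw, c) tensors.
        b, c, h, w = q.shape
        q_, k_, v_ = (t.reshape(b, 1, c, h * w).transpose(2, 3) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q_, k_, v_)  # softmax(QK^T / sqrt(c)) V
        return out.transpose(2, 3).reshape(b, c, h, w)


    if __name__ == "__main__":
        q, k, v = (torch.randn(2, 16, 8, 8) for _ in range(3))
        assert torch.allclose(manual_attention(q, k, v), fused_attention(q, k, v), atol=1e-5)

Calling vae_attention() once at module construction keeps the backend choice (xformers, pytorch, or split attention) out of the per-forward hot path and lets CausalAttnBlock reuse the same dispatch logic as the 2D AttnBlock instead of duplicating it.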