Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-16 08:33:29 +00:00)
Optimize first attention block in cosmos VAE.
commit 008761166f (parent bfd5dfd611)
@@ -30,6 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging
 
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
 from .patching import (
     Patcher,
     Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
 
+        self.optimized_attention = vae_attention()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
         v, batch_size = time2batch(v)
 
         b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)
-        k = k.reshape(b, c, h * w)
-        w_ = torch.bmm(q, k)
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = F.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)
-        h_ = torch.bmm(v, w_)
-        h_ = h_.reshape(b, c, h, w)
+        h_ = self.optimized_attention(q, k, v)
 
         h_ = batch2time(h_, batch_size)
         h_ = self.proj_out(h_)
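The deleted lines are a hand-rolled scaled dot-product attention over the h*w spatial tokens (reshape to matrices, bmm, scale by c^-0.5, softmax, bmm back); the commit replaces them with one call to whichever backend vae_attention() selected at construction time. A minimal sketch, not part of the commit, checking that the removed math matches PyTorch's fused kernel (manual_attention and sdp_attention are hypothetical names for illustration):

# Sketch, not from the commit: verify the removed bmm path equals fused SDPA.
import torch
import torch.nn.functional as F

def manual_attention(q, k, v):
    # The code path this commit deletes from CausalAttnBlock.forward.
    b, c, h, w = q.shape
    q_ = q.reshape(b, c, h * w).permute(0, 2, 1)      # (b, hw, c)
    k_ = k.reshape(b, c, h * w)                       # (b, c, hw)
    w_ = F.softmax(torch.bmm(q_, k_) * c ** -0.5, dim=2)
    v_ = v.reshape(b, c, h * w)
    h_ = torch.bmm(v_, w_.permute(0, 2, 1))           # (b, c, hw)
    return h_.reshape(b, c, h, w)

def sdp_attention(q, k, v):
    # One backend vae_attention() can return (pytorch_attention uses SDPA).
    b, c, h, w = q.shape
    q_, k_, v_ = (t.reshape(b, 1, c, h * w).transpose(2, 3) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_, k_, v_)  # default scale is 1/sqrt(c)
    return out.transpose(2, 3).reshape(b, c, h, w)

q, k, v = (torch.randn(2, 16, 8, 8) for _ in range(3))
assert torch.allclose(manual_attention(q, k, v), sdp_attention(q, k, v), atol=1e-5)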
comfy/ldm/modules/diffusionmodules/model.py
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
         return out
 
 
+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                    stride=1,
                                    padding=0)
 
-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()
 
     def forward(self, x):
         h_ = x
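With this change both AttnBlock and CausalAttnBlock resolve their attention backend the same way: call vae_attention() once in __init__ and keep the returned function as a plain callable, so the xformers/pytorch/split ladder lives in exactly one place. A minimal standalone sketch of that pattern, assuming ComfyUI is importable; ToyAttnBlock is a hypothetical stand-in, not code from the repository:

# Sketch of the pattern the commit converges on; ToyAttnBlock is hypothetical.
import torch
import torch.nn as nn
from comfy.ldm.modules.diffusionmodules.model import vae_attention  # assumes ComfyUI on sys.path

class ToyAttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(1, in_channels)  # single group keeps the toy simple
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # backend is chosen once, at build time, based on what the environment supports
        self.optimized_attention = vae_attention()

    def forward(self, x):
        h_ = self.norm(x)
        h_ = self.optimized_attention(self.q(h_), self.k(h_), self.v(h_))
        return x + self.proj_out(h_)

block = ToyAttnBlock(64)
y = block(torch.randn(1, 64, 32, 32))  # same (b, c, h, w) shape out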