diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 759788a9..b085bbc0 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional
 import numpy as np
 import torch
 import torch.nn as nn
-from .. import attention
+from ..attention import optimized_attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]
 
-def optimized_attention(qkv, num_heads):
-    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
 
 class SelfAttention(nn.Module):
     ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        qkv = self.pre_attention(x)
+        q, k, v = self.pre_attention(x)
         x = optimized_attention(
-            qkv, num_heads=self.num_heads
+            q, k, v, heads=self.num_heads
         )
         x = self.post_attention(x)
         return x
@@ -531,8 +529,8 @@ class DismantledBlock(nn.Module):
         assert not self.pre_only
         qkv, intermediates = self.pre_attention(x, c)
         attn = optimized_attention(
-            qkv,
-            num_heads=self.attn.num_heads,
+            qkv[0], qkv[1], qkv[2],
+            heads=self.attn.num_heads,
         )
         return self.post_attention(attn, *intermediates)
 
@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
     qkv = tuple(o)
 
     attn = optimized_attention(
-        qkv,
-        num_heads=x_block.attn.num_heads,
+        qkv[0], qkv[1], qkv[2],
+        heads=x_block.attn.num_heads,
     )
     context_attn, x_attn = (
         attn[:, : context_qkv[0].shape[1]],
@@ -642,7 +640,7 @@ class SelfAttentionContext(nn.Module):
     def forward(self, x):
         qkv = self.qkv(x)
         q, k, v = split_qkv(qkv, self.dim_head)
-        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
         return self.proj(x)
 
 class ContextProcessorBlock(nn.Module):
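# --- illustrative sketch, not part of the patch -----------------------------
# A minimal example of the calling convention introduced by this diff: the
# removed local wrapper took a packed qkv tuple, while the shared
# comfy.ldm.modules.attention.optimized_attention is called with q, k, v
# passed separately and the head count given as the `heads=` keyword.
# The tensor sizes below are assumptions chosen only for illustration;
# split_qkv and optimized_attention are the functions referenced in the diff.
import torch
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import split_qkv

batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
qkv = torch.randn(batch, seq_len, 3 * num_heads * head_dim)

# split_qkv unpacks the fused projection into q, k, v of shape
# (batch, seq_len, num_heads, head_dim)
q, k, v = split_qkv(qkv, head_dim)

# Mirrors SelfAttentionContext.forward after the change: q is flattened back
# to (batch, seq_len, num_heads * head_dim) and heads is a keyword argument.
out = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=num_heads)
print(out.shape)  # expected: (batch, seq_len, num_heads * head_dim)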