From 129d8908f7e21c09c9f47954ab6d0539473fa982 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 10 Jan 2025 06:27:37 -0500
Subject: [PATCH] Add argument to skip the output reshaping in the attention
 functions.

---
 comfy/ldm/modules/attention.py | 83 ++++++++++++++++++++++------------
 1 file changed, 53 insertions(+), 30 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 0d54e6be..44aec59a 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -89,7 +89,7 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
 
     if skip_reshape:
@@ -142,16 +142,23 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     sim = sim.softmax(dim=-1)
 
     out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-    out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
-    )
+
+    if skip_output_reshape:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
     return out
 
 
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
 
     if skip_reshape:
@@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     )
 
     hidden_states = hidden_states.to(dtype)
-
-    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
+    if skip_output_reshape:
+        hidden_states = hidden_states.unflatten(0, (-1, heads))
+    else:
+        hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states
 
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
 
     if skip_reshape:
@@ -326,12 +335,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 
     del q, k, v
 
-    r1 = (
-        r1.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
     return r1
 
 BROKEN_XFORMERS = False
@@ -342,7 +357,7 @@ try:
 except:
     pass
 
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
@@ -395,9 +410,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
-    out = (
-        out.reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = (
+            out.reshape(b, -1, heads * dim_head)
+        )
 
     return out
 
@@ -408,7 +426,7 @@ else:
     #TODO: other GPUs ?
     SDP_BATCH_LIMIT = 2**31
 
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -429,9 +447,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
     if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
     else:
         out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
         for i in range(0, b, SDP_BATCH_LIMIT):
@@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
 
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout="HND"
@@ -473,11 +492,15 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
 
     if tensor_layout == "HND":
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
     else:
-        out = out.reshape(b, -1, heads * dim_head)
+        if skip_output_reshape:
+            out = out.transpose(1, 2)
+        else:
+            out = out.reshape(b, -1, heads * dim_head)
 
     return out
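
Note (commentary, not part of the patch): a minimal shape check of what the new flag changes, assuming the patch is applied and ComfyUI is importable as a package. attention_pytorch and the comfy.ldm.modules.attention path are taken from the diff above; the toy tensor sizes are illustrative only.

    import torch
    from comfy.ldm.modules.attention import attention_pytorch

    b, seq, heads, dim_head = 2, 77, 8, 64
    q = torch.randn(b, seq, heads * dim_head)
    k = torch.randn(b, seq, heads * dim_head)
    v = torch.randn(b, seq, heads * dim_head)

    # Default behaviour is unchanged: output is flattened back to
    # (batch, seq_len, heads * dim_head).
    out = attention_pytorch(q, k, v, heads)
    assert out.shape == (b, seq, heads * dim_head)

    # With the new flag the per-head layout from scaled_dot_product_attention
    # is returned as-is: (batch, heads, seq_len, dim_head).
    out = attention_pytorch(q, k, v, heads, skip_output_reshape=True)
    assert out.shape == (b, heads, seq, dim_head)

In other words, callers that already work in the (batch, heads, seq_len, dim_head) layout can pass skip_output_reshape=True (typically together with skip_reshape=True on the input side) and avoid a transpose/reshape round trip, while the default output layout of every backend stays exactly as before.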