Add argument to skip the output reshaping in the attention functions.

comfyanonymous 2025-01-10 06:27:37 -05:00
parent ff838657fa
commit 129d8908f7

@@ -89,7 +89,7 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

     if skip_reshape:
@@ -142,6 +142,13 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     sim = sim.softmax(dim=-1)

     out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-    out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
+    if skip_output_reshape:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
@@ -151,7 +158,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out

-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

     if skip_reshape:
@@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     )
     hidden_states = hidden_states.to(dtype)
-    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
+    if skip_output_reshape:
+        hidden_states = hidden_states.unflatten(0, (-1, heads))
+    else:
+        hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states

-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

     if skip_reshape:
@@ -326,6 +335,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     del q, k, v

-    r1 = (
-        r1.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
+    if skip_output_reshape:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
@@ -342,7 +357,7 @@ try:
 except:
     pass

-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
@@ -395,6 +410,9 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

-    out = (
-        out.reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = (
+            out.reshape(b, -1, heads * dim_head)
+        )
@@ -408,7 +426,7 @@ else:
 SDP_BATCH_LIMIT = 2**31

-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -429,6 +447,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
@@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out

-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout="HND"
@@ -473,9 +492,13 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)

     if tensor_layout == "HND":
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
     else:
-        out = out.reshape(b, -1, heads * dim_head)
+        if skip_output_reshape:
+            out = out.transpose(1, 2)
+        else:
+            out = out.reshape(b, -1, heads * dim_head)
     return out
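
For orientation, here is a minimal runnable sketch of what the new flag does, modeled on the attention_pytorch path in the diff above. It is not the ComfyUI function itself; the name attention_pytorch_sketch and the example tensor sizes are assumptions made for illustration.

    import torch
    import torch.nn.functional as F

    def attention_pytorch_sketch(q, k, v, heads, mask=None, skip_reshape=False, skip_output_reshape=False):
        # Inputs are (b, seq, heads * dim_head) unless skip_reshape=True,
        # in which case they already come in as (b, heads, seq, dim_head).
        if skip_reshape:
            b, _, _, dim_head = q.shape
        else:
            b, _, dim_head = q.shape
            dim_head //= heads
            q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

        if not skip_output_reshape:
            # Default behaviour: merge the heads back -> (b, seq, heads * dim_head).
            out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
        # With skip_output_reshape=True the result stays in (b, heads, seq, dim_head).
        return out

    q = k = v = torch.randn(2, 16, 8 * 64)
    print(attention_pytorch_sketch(q, k, v, heads=8).shape)                            # torch.Size([2, 16, 512])
    print(attention_pytorch_sketch(q, k, v, heads=8, skip_output_reshape=True).shape)  # torch.Size([2, 8, 16, 64])

As the diff shows, every backend (basic, sub_quad, split, xformers, pytorch, sage) gains the same skip_output_reshape argument, and in each case setting it leaves the output in the per-head (b, heads, seq, dim_head) layout instead of merging the heads back into (b, seq, heads * dim_head), presumably so callers that already work in the per-head layout can avoid a round-trip reshape.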