mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add argument to skip the output reshaping in the attention functions.
This commit is contained in:
parent
ff838657fa
commit
129d8908f7
@ -89,7 +89,7 @@ class FeedForward(nn.Module):
|
|||||||
def Normalize(in_channels, dtype=None, device=None):
|
def Normalize(in_channels, dtype=None, device=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -142,16 +142,23 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
sim = sim.softmax(dim=-1)
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||||
out = (
|
|
||||||
out.unsqueeze(0)
|
if skip_output_reshape:
|
||||||
.reshape(b, heads, -1, dim_head)
|
out = (
|
||||||
.permute(0, 2, 1, 3)
|
out.unsqueeze(0)
|
||||||
.reshape(b, -1, heads * dim_head)
|
.reshape(b, heads, -1, dim_head)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
out = (
|
||||||
|
out.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states.to(dtype)
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
if skip_output_reshape:
|
||||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
hidden_states = hidden_states.unflatten(0, (-1, heads))
|
||||||
|
else:
|
||||||
|
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -326,12 +335,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
r1 = (
|
if skip_output_reshape:
|
||||||
r1.unsqueeze(0)
|
r1 = (
|
||||||
.reshape(b, heads, -1, dim_head)
|
r1.unsqueeze(0)
|
||||||
.permute(0, 2, 1, 3)
|
.reshape(b, heads, -1, dim_head)
|
||||||
.reshape(b, -1, heads * dim_head)
|
)
|
||||||
)
|
else:
|
||||||
|
r1 = (
|
||||||
|
r1.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
return r1
|
return r1
|
||||||
|
|
||||||
BROKEN_XFORMERS = False
|
BROKEN_XFORMERS = False
|
||||||
@ -342,7 +357,7 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
b = q.shape[0]
|
b = q.shape[0]
|
||||||
dim_head = q.shape[-1]
|
dim_head = q.shape[-1]
|
||||||
# check to make sure xformers isn't broken
|
# check to make sure xformers isn't broken
|
||||||
@ -395,9 +410,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||||||
|
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||||
|
|
||||||
out = (
|
if skip_output_reshape:
|
||||||
out.reshape(b, -1, heads * dim_head)
|
out = out.permute(0, 2, 1, 3)
|
||||||
)
|
else:
|
||||||
|
out = (
|
||||||
|
out.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -408,7 +426,7 @@ else:
|
|||||||
SDP_BATCH_LIMIT = 2**31
|
SDP_BATCH_LIMIT = 2**31
|
||||||
|
|
||||||
|
|
||||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
else:
|
else:
|
||||||
@ -429,9 +447,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
|
|
||||||
if SDP_BATCH_LIMIT >= b:
|
if SDP_BATCH_LIMIT >= b:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
out = (
|
if not skip_output_reshape:
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out = (
|
||||||
)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||||
for i in range(0, b, SDP_BATCH_LIMIT):
|
for i in range(0, b, SDP_BATCH_LIMIT):
|
||||||
@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
tensor_layout="HND"
|
tensor_layout="HND"
|
||||||
@ -473,11 +492,15 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
|
|
||||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||||
if tensor_layout == "HND":
|
if tensor_layout == "HND":
|
||||||
out = (
|
if not skip_output_reshape:
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out = (
|
||||||
)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
out = out.reshape(b, -1, heads * dim_head)
|
if skip_output_reshape:
|
||||||
|
out = out.transpose(1, 2)
|
||||||
|
else:
|
||||||
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user