Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 03:13:30 +00:00)

Merge 01abd4b9b4 into 8af9a91e0c
Commit e015462b06
@@ -88,6 +88,7 @@ def Normalize(in_channels, dtype=None, device=None):
 
 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
+    cast_to_type = attn_precision if attn_precision is not None else q.dtype
 
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -113,12 +114,9 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         (q, k, v),
     )
 
-    # force cast to fp32 to avoid overflowing
-    if attn_precision == torch.float32:
-        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
-    else:
-        sim = einsum('b i d, b j d -> b i j', q, k) * scale
+    # force cast to fp32 to avoid overflowing if args.dont_upcast_attention is not set
+    sim = einsum('b i d, b j d -> b i j', q.to(dtype=cast_to_type), k.to(dtype=cast_to_type)) * scale
 
     del q, k
 
     if exists(mask):
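The two attention_basic hunks above replace the hard-coded fp32 branch with a single cast driven by attn_precision. A minimal, self-contained sketch of that pattern (scores_with_cast is a hypothetical helper for illustration, not part of ComfyUI; it inlines the cast_to_type line from the diff rather than calling get_attn_precision):

import torch
from torch import einsum

def scores_with_cast(q, k, scale, attn_precision=None):
    # Same idea as the hunk above: pick the compute dtype once, then cast q and k
    # before the einsum instead of branching on float32.
    cast_to_type = attn_precision if attn_precision is not None else q.dtype
    return einsum('b i d, b j d -> b i j',
                  q.to(dtype=cast_to_type),
                  k.to(dtype=cast_to_type)) * scale

# Hypothetical usage: fp16 inputs, similarity computed in fp32 to avoid overflow.
q = torch.randn(2, 8, 64, dtype=torch.float16)
k = torch.randn(2, 8, 64, dtype=torch.float16)
sim = scores_with_cast(q, k, scale=64 ** -0.5, attn_precision=torch.float32)

When attn_precision is None, cast_to_type falls back to q.dtype, so the cast is a no-op and behaviour matches the old non-upcast branch.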
@@ -291,7 +289,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             end = i + slice_size
             if upcast:
                 with torch.autocast(enabled=False, device_type = 'cuda'):
-                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
+                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end].to(dtype=torch.float32), k.to(dtype=torch.float32)) * scale
             else:
                 s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
 
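The attention_split change appears to be purely notational: on a tensor, .float() and .to(dtype=torch.float32) perform the same cast, and the patch switches to the explicit form for consistency with the new cast_to_type code. A quick illustrative check:

import torch

x = torch.randn(4, 4, dtype=torch.float16)
assert torch.equal(x.float(), x.to(dtype=torch.float32))  # same values, same dtype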
@@ -344,6 +342,8 @@ except:
     pass
 
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    attn_precision = get_attn_precision(attn_precision)
+    cast_to_type = attn_precision if attn_precision is not None else q.dtype
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -365,12 +365,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 
     if skip_reshape:
         q, k, v = map(
-            lambda t: t.reshape(b * heads, -1, dim_head),
-            (q, k, v),
-        )
+            lambda t: t.reshape(b * heads, -1, dim_head).to(dtype=cast_to_type),
+            (q, k, v)
+        )
     else:
         q, k, v = map(
-            lambda t: t.reshape(b, -1, heads, dim_head),
+            lambda t: t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type),
             (q, k, v),
         )
 
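The attention_xformers hunks fold the same cast into the existing reshape lambdas. A standalone sketch of just that reshape-and-cast step (reshape_and_cast is a hypothetical helper and does not require xformers):

import torch

def reshape_and_cast(q, k, v, heads, cast_to_type, skip_reshape=False):
    # Mirrors the map(...) calls above; the diff's only change is the trailing
    # .to(dtype=cast_to_type) on each reshaped tensor.
    if skip_reshape:
        b, _, _, dim_head = q.shape
        return tuple(t.reshape(b * heads, -1, dim_head).to(dtype=cast_to_type) for t in (q, k, v))
    b, _, dim_head = q.shape
    dim_head //= heads
    return tuple(t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type) for t in (q, k, v))

# Hypothetical usage with fp16 inputs and an fp32 override:
q = k = v = torch.randn(1, 77, 8 * 64, dtype=torch.float16)
q, k, v = reshape_and_cast(q, k, v, heads=8, cast_to_type=torch.float32)
print(q.shape, q.dtype)  # torch.Size([1, 77, 8, 64]) torch.float32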
@@ -404,13 +404,20 @@ else:
 
 
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    attn_precision = get_attn_precision(attn_precision)
+    cast_to_type = attn_precision if attn_precision is not None else q.dtype
+
     if skip_reshape:
         b, _, _, dim_head = q.shape
+        q, k, v = map(
+            lambda t: t.to(dtype=cast_to_type), (q, k, v),
+        )
     else:
         b, _, dim_head = q.shape
         dim_head //= heads
         q, k, v = map(
-            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
             (q, k, v),
         )
 
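The attention_pytorch hunk applies the same cast before handing q, k, v to PyTorch's scaled_dot_product_attention. A minimal sketch under the same assumptions (sdpa_with_cast is illustrative only; it omits the mask handling and output reshape present in the real function):

import torch
import torch.nn.functional as F

def sdpa_with_cast(q, k, v, heads, attn_precision=None):
    # Derive the compute dtype once, as in the diff, then cast while reshaping.
    cast_to_type = attn_precision if attn_precision is not None else q.dtype
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
        (q, k, v),
    )
    out = F.scaled_dot_product_attention(q, k, v)  # (b, heads, seq, dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(1, 77, 8 * 64, dtype=torch.float16)
out = sdpa_with_cast(q, k, v, heads=8, attn_precision=torch.float32)
print(out.shape, out.dtype)  # torch.Size([1, 77, 512]) torch.float32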