Commit e015462b06 by shawnington, 2024-12-06 03:32:53 -08:00 (committed by GitHub)


@@ -88,6 +88,7 @@ def Normalize(in_channels, dtype=None, device=None):
 
 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
+    cast_to_type = attn_precision if attn_precision is not None else q.dtype
 
     if skip_reshape:
         b, _, _, dim_head = q.shape
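
Note: the added cast_to_type line resolves the dtype used for the attention math: the configured attention precision if one is set, otherwise the query's own dtype. A minimal sketch of that selection, using a simplified stand-in for get_attn_precision (the FORCE_UPCAST_ATTENTION_DTYPE toggle below is an assumption for illustration, not ComfyUI's actual config):

import torch

FORCE_UPCAST_ATTENTION_DTYPE = torch.float32  # assumed toggle; stand-in for the real config

def get_attn_precision(attn_precision):
    # Simplified stand-in: force fp32 when the toggle is set, otherwise pass through.
    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
        return FORCE_UPCAST_ATTENTION_DTYPE
    return attn_precision

q = torch.randn(2, 77, 512, dtype=torch.float16)
attn_precision = get_attn_precision(None)
cast_to_type = attn_precision if attn_precision is not None else q.dtype
print(cast_to_type)  # torch.float32 here; torch.float16 if no precision is forced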
@@ -113,12 +114,9 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             (q, k, v),
         )
 
-    # force cast to fp32 to avoid overflowing
-    if attn_precision == torch.float32:
-        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
-    else:
-        sim = einsum('b i d, b j d -> b i j', q, k) * scale
+    # force cast to fp32 to avoid overflowing if args.dont_upcast_attention is not set
+    sim = einsum('b i d, b j d -> b i j', q.to(dtype=cast_to_type), k.to(dtype=cast_to_type)) * scale
 
     del q, k
 
     if exists(mask):
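
Note: the old branch hard-cast q and k with .float() only when attn_precision was exactly torch.float32; the single einsum above routes every case through cast_to_type, which degenerates to a no-op cast when no upcast is requested. A toy illustration of why the upcast matters: each attention logit is a dot product over dim_head elements, and fp16 overflows long before fp32 (the values here are chosen to overflow and are not representative of real activations):

import torch

# One attention logit = dot product of a query row and a key row over dim_head elements.
q_row = torch.full((64,), 60.0, dtype=torch.float16)
k_row = torch.full((64,), 60.0, dtype=torch.float16)

logit_fp16 = (q_row * k_row).sum()                                       # 60*60*64 = 230400 -> inf in fp16
logit_fp32 = (q_row.to(torch.float32) * k_row.to(torch.float32)).sum()   # stays finite in fp32

print(logit_fp16)  # tensor(inf, dtype=torch.float16)
print(logit_fp32)  # tensor(230400.)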
@@ -291,7 +289,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 end = i + slice_size
                 if upcast:
                     with torch.autocast(enabled=False, device_type = 'cuda'):
-                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
+                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].to(dtype=torch.float32), k.to(dtype=torch.float32)) * scale
                 else:
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
 
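
Note: in attention_split the change swaps .float() for the equivalent explicit .to(dtype=torch.float32); the surrounding slicing logic is unchanged. A self-contained sketch of the sliced-upcast pattern this hunk lives in (the function and parameter names below are illustrative, not ComfyUI's API):

import torch
from torch import einsum

def sliced_similarity(q, k, scale, slice_size, upcast=True):
    # Process the query in chunks along the sequence dimension; when upcasting,
    # disable autocast so the explicit fp32 cast is not silently undone on CUDA.
    sims = []
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        if upcast:
            with torch.autocast(enabled=False, device_type='cuda'):
                s1 = einsum('b i d, b j d -> b i j',
                            q[:, i:end].to(dtype=torch.float32),
                            k.to(dtype=torch.float32)) * scale
        else:
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        sims.append(s1)
    return torch.cat(sims, dim=1)

q = torch.randn(1, 8, 16)
k = torch.randn(1, 8, 16)
print(sliced_similarity(q, k, scale=16 ** -0.5, slice_size=4).shape)  # torch.Size([1, 8, 8])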
@@ -344,6 +342,8 @@ except:
     pass
 
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    attn_precision = get_attn_precision(attn_precision)
+    cast_to_type = attn_precision if attn_precision is not None else q.dtype
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -365,12 +365,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 
     if skip_reshape:
         q, k, v = map(
-            lambda t: t.reshape(b * heads, -1, dim_head),
-            (q, k, v),
-        )
+            lambda t: t.reshape(b * heads, -1, dim_head).to(dtype=cast_to_type),
+            (q, k, v)
+        )
     else:
         q, k, v = map(
-            lambda t: t.reshape(b, -1, heads, dim_head),
+            lambda t: t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type),
             (q, k, v),
         )
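
Note: this hunk folds the dtype cast into the existing reshape lambdas, so q, k and v all reach xformers.ops.memory_efficient_attention in the same dtype and already in the expected [batch, seq, heads, dim_head] layout. A hedged sketch of that reshape-and-cast step; PyTorch's scaled_dot_product_attention is used as a stand-in for the attention call so the snippet runs without xformers installed:

import torch

def reshape_and_cast(q, k, v, heads, attn_precision=None):
    # Mirror of the non-skip_reshape branch: split heads out of the channel dim
    # and cast to the requested attention dtype in one map() call.
    b, _, dim_head = q.shape
    dim_head //= heads
    cast_to_type = attn_precision if attn_precision is not None else q.dtype
    q, k, v = map(
        lambda t: t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type),
        (q, k, v),
    )
    return q, k, v

q = k = v = torch.randn(2, 77, 8 * 64, dtype=torch.float16)
q, k, v = reshape_and_cast(q, k, v, heads=8, attn_precision=torch.float32)
print(q.dtype, q.shape)  # torch.float32 torch.Size([2, 77, 8, 64])

# Stand-in attention call on the same [batch, seq, heads, dim_head] layout:
out = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2).reshape(2, 77, -1)
print(out.shape)  # torch.Size([2, 77, 512])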
@@ -404,13 +404,20 @@ else:
 
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    attn_precision = get_attn_precision(attn_precision)
+    cast_to_type = attn_precision if attn_precision is not None else q.dtype
 
     if skip_reshape:
         b, _, _, dim_head = q.shape
+        q, k, v = map(
+            lambda t: t.to(dtype=cast_to_type), (q, k, v),
+        )
     else:
         b, _, dim_head = q.shape
         dim_head //= heads
         q, k, v = map(
-            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
             (q, k, v),
         )
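
Note: with this hunk the PyTorch SDPA path also honours attn_precision: q, k and v are viewed as [batch, heads, seq, dim_head] and cast to cast_to_type before the call, so requesting fp32 precision upcasts the whole attention op even when the model weights are fp16. A hedged, self-contained sketch of that path (the helper name is illustrative):

import torch

def sdpa_attention_sketch(q, k, v, heads, attn_precision=None):
    cast_to_type = attn_precision if attn_precision is not None else q.dtype
    b, _, dim_head = q.shape
    dim_head //= heads
    # Split heads, move them ahead of the sequence dim, and cast in one pass.
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
        (q, k, v),
    )
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)
    # Merge heads back into the channel dimension.
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(1, 64, 4 * 32, dtype=torch.float16)
out = sdpa_attention_sketch(q, k, v, heads=4, attn_precision=torch.float32)
print(out.dtype, out.shape)  # torch.float32 torch.Size([1, 64, 128])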