diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 8d86aa53..3e12886b 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
 
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
 
diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py
index 0bcda84f..7e293ef6 100644
--- a/comfy_extras/nodes_sag.py
+++ b/comfy_extras/nodes_sag.py
@@ -27,9 +27,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
 
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
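
For context, here is a minimal, self-contained sketch (not ComfyUI's exact code; the function name `scaled_scores_fp32` and the tensor shapes are illustrative) of the fp32-upcast pattern the patch keeps: the attention scores are computed from fp32 copies of q and k so the dot products cannot overflow fp16, without wrapping the computation in a CUDA-specific `torch.autocast(enabled=False, device_type='cuda')` block.

import torch
from torch import einsum

def scaled_scores_fp32(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # q: (batch*heads, q_len, dim), k: (batch*heads, k_len, dim), e.g. fp16.
    scale = q.shape[-1] ** -0.5
    # Inline upcast: no autocast context manager, and nothing here is tied
    # to a particular device type.
    return einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale

if __name__ == "__main__":
    q = torch.randn(2, 16, 64, dtype=torch.float16)
    k = torch.randn(2, 16, 64, dtype=torch.float16)
    sim = scaled_scores_fp32(q, k)
    print(sim.dtype, sim.shape)  # torch.float32 torch.Size([2, 16, 16])

Outside an active autocast region the inline cast produces the same fp32 result as the removed wrapper; dropping the hardcoded `device_type='cuda'` is what lets the same code path run unchanged on non-CUDA backends.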