Fix overflow issue with inplace softmax.

This commit is contained in:
comfyanonymous 2023-02-10 11:47:41 -05:00
parent 509c7dfc6d
commit 1a4edd19cd

View File

@ -158,6 +158,7 @@ def _get_attention_scores_no_kv_chunking(
del attn_scores
except OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed