Lower the chances of an OOM.

This commit is contained in:
comfyanonymous 2023-02-08 14:24:27 -05:00
parent 853e96ada3
commit 047775615b

View File

@ -76,7 +76,8 @@ def _summarize_chunk(
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
torch.exp(attn_weights - max_score, out=attn_weights)
exp_weights = attn_weights
exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)