This commit is contained in:
comfyanonymous 2025-02-11 17:17:03 -05:00
parent b124256817
commit d9f0fcdb0c

View File

@ -104,10 +104,7 @@ class CLIPTextModel_(torch.nn.Module):
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
if comfy.model_management.is_directml_enabled():
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
else:
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
if mask is not None:
mask += causal_mask