diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 0163c6fe7..cf5b58b62 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -104,10 +104,7 @@ class CLIPTextModel_(torch.nn.Module): mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) - if comfy.model_management.is_directml_enabled(): - causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1) - else: - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1) if mask is not None: mask += causal_mask