diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 23ddea9c0..c48576028 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -102,9 +102,9 @@ class CLIPTextModel_(torch.nn.Module):
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
 
-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(-torch.finfo(x.dtype).max).triu_(1)
         if mask is not None:
             mask += causal_mask
         else:
diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py
index fc9bac1d2..d4edd5aa5 100644
--- a/comfy/text_encoders/bert.py
+++ b/comfy/text_encoders/bert.py
@@ -118,7 +118,7 @@ class BertModel_(torch.nn.Module):
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
 
         x, i = self.encoder(x, mask, intermediate_output)
         return x, i
diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py
index 7405528e2..df2b5b5cd 100644
--- a/comfy/text_encoders/t5.py
+++ b/comfy/text_encoders/t5.py
@@ -203,7 +203,7 @@ class T5Stack(torch.nn.Module):
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
 
         intermediate = None
         optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
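
The diff does not state its motivation, but the likely numerical reason for preferring `-torch.finfo(x.dtype).max` over `float("-inf")` as the mask fill value: if every entry in a softmax row is `-inf` (e.g. a fully padded sequence), the softmax evaluates to NaN and poisons the rest of the forward pass, whereas a large-but-finite minimum degrades gracefully to a uniform distribution. A minimal standalone sketch of the difference (not taken from the patched files):

```python
import torch

# A softmax row that is entirely -inf produces NaN; a finite minimum
# fill value keeps the row well-defined (uniform weights instead).
scores = torch.zeros(1, 4, dtype=torch.float16)
all_masked = torch.ones(1, 4, dtype=torch.bool)  # every position masked out

with_inf = scores.masked_fill(all_masked, float("-inf"))
with_finite = scores.masked_fill(all_masked, -torch.finfo(scores.dtype).max)

print(torch.softmax(with_inf, dim=-1))     # tensor([[nan, nan, nan, nan]], dtype=torch.float16)
print(torch.softmax(with_finite, dim=-1))  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]], dtype=torch.float16)
```

Note also that the fill value is derived from `x.dtype` rather than hard-coded, so the same code picks the correct representable minimum whether the text encoders run in fp16, bf16, or fp32.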