diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 53f27e3a7..09dd2482c 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -187,6 +187,9 @@ class Flux(nn.Module): if add is not None: img[:, txt.shape[1] :, ...] += add + if img.dtype == torch.float16: + img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) + img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)