From 8115d8cce97a3edaaad8b08b45ab37c6782e1cb4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 15:08:39 -0400 Subject: [PATCH] Add Flux fp16 support hack. --- comfy/ldm/flux/layers.py | 9 ++++++++- comfy/supported_models.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 99f49810..4a0bd40c 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -188,6 +188,10 @@ class DoubleStreamBlock(nn.Module): # calculate the txt bloks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + + if txt.dtype == torch.float16: + txt = txt.clip(-65504, 65504) + return img, txt @@ -239,7 +243,10 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - return x + mod.gate * output + x = x + mod.gate * output + if x.dtype == torch.float16: + x = x.clip(-65504, 65504) + return x class LastLayer(nn.Module): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6cecb9a0..d07a7106 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE): memory_usage_factor = 2.8 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."]