mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add Flux fp16 support hack.
This commit is contained in:
parent
6969fc9ba4
commit
8115d8cce9
@ -188,6 +188,10 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
# calculate the txt blocks
|
# calculate the txt blocks
|
||||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
|
|
||||||
|
if txt.dtype == torch.float16:
|
||||||
|
txt = txt.clip(-65504, 65504)
|
||||||
|
|
||||||
return img, txt
|
return img, txt
|
||||||
|
|
||||||
|
|
||||||
@ -239,7 +243,10 @@ class SingleStreamBlock(nn.Module):
|
|||||||
attn = attention(q, k, v, pe=pe)
|
attn = attention(q, k, v, pe=pe)
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
return x + mod.gate * output
|
x = x + mod.gate * output
|
||||||
|
if x.dtype == torch.float16:
|
||||||
|
x = x.clip(-65504, 65504)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LastLayer(nn.Module):
|
class LastLayer(nn.Module):
|
||||||
|
@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE):
|
|||||||
|
|
||||||
memory_usage_factor = 2.8
|
memory_usage_factor = 2.8
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
Loading…
Reference in New Issue
Block a user