mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
UNET weights can now be stored in fp8.
--fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats supported by pytorch.
This commit is contained in:
parent
af365e4dd1
commit
31b0f6f3d8
@ -283,7 +283,7 @@ class ControlNet(nn.Module):
|
||||
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
@ -295,7 +295,7 @@ class ControlNet(nn.Module):
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x
|
||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
|
@ -55,7 +55,10 @@ fp_group = parser.add_mutually_exclusive_group()
|
||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group = parser.add_mutually_exclusive_group()
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||
|
||||
fpvae_group = parser.add_mutually_exclusive_group()
|
||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import contextlib
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@ -147,24 +148,31 @@ class ControlNet(ControlBase):
|
||||
else:
|
||||
return None
|
||||
|
||||
dtype = self.control_model.dtype
|
||||
if comfy.model_management.supports_dtype(self.device, dtype):
|
||||
precision_scope = lambda a: contextlib.nullcontext(a)
|
||||
else:
|
||||
precision_scope = torch.autocast
|
||||
dtype = torch.float32
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
|
||||
context = cond['c_crossattn']
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
y = y.to(dtype)
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
|
||||
with precision_scope(comfy.model_management.get_autocast_device(self.device)):
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
|
@ -841,14 +841,14 @@ class UNetModel(nn.Module):
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
|
@ -5,6 +5,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
from enum import Enum
|
||||
import contextlib
|
||||
from . import utils
|
||||
|
||||
class ModelType(Enum):
|
||||
@ -61,6 +62,13 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
|
||||
if comfy.model_management.supports_dtype(xc.device, dtype):
|
||||
precision_scope = lambda a: contextlib.nullcontext(a)
|
||||
else:
|
||||
precision_scope = torch.autocast
|
||||
dtype = torch.float32
|
||||
|
||||
xc = xc.to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
@ -70,7 +78,10 @@ class BaseModel(torch.nn.Module):
|
||||
if hasattr(extra, "to"):
|
||||
extra = extra.to(dtype)
|
||||
extra_conds[o] = extra
|
||||
|
||||
with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def get_dtype(self):
|
||||
|
@ -459,6 +459,10 @@ def unet_inital_load_device(parameters, dtype):
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp8_e4m3fn_unet:
|
||||
return torch.float8_e4m3fn
|
||||
if args.fp8_e5m2_unet:
|
||||
return torch.float8_e5m2
|
||||
if should_use_fp16(device=device, model_params=model_params):
|
||||
return torch.float16
|
||||
return torch.float32
|
||||
@ -515,6 +519,17 @@ def get_autocast_device(dev):
|
||||
return dev.type
|
||||
return "cuda"
|
||||
|
||||
def supports_dtype(device, dtype): #TODO
|
||||
if dtype == torch.float32:
|
||||
return True
|
||||
if torch.device("cpu") == device:
|
||||
return False
|
||||
if dtype == torch.float16:
|
||||
return True
|
||||
if dtype == torch.bfloat16:
|
||||
return True
|
||||
return False
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
|
Loading…
Reference in New Issue
Block a user