Mixed precision diffusion models with scaled fp8.

This change allows supports for diffusion models where all the linears are
scaled fp8 while the other weights are the original precision.
This commit is contained in:
comfyanonymous 2024-10-21 18:12:51 -04:00
parent 83ca891118
commit 0075c6d096
5 changed files with 14 additions and 11 deletions

View File

@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8) fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else: else:
operations = model_config.custom_operations operations = model_config.custom_operations
@ -246,8 +246,8 @@ class BaseModel(torch.nn.Module):
unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8: if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([]) unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

View File

@ -288,8 +288,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match: if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config) model_config = comfy.supported_models_base.BASE(unet_config)
if "{}scaled_fp8".format(unet_key_prefix) in state_dict: scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
model_config.scaled_fp8 = True if scaled_fp8_weight is not None:
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
return model_config return model_config

View File

@ -334,10 +334,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return scaled_fp8_op return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8: if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True) return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8: if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops return fp8_ops

View File

@ -579,7 +579,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None return None
unet_weight_dtype = list(model_config.supported_inference_dtypes) unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None: if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype) unet_weight_dtype.append(weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None) model_config.custom_operations = model_options.get("custom_operations", None)
@ -677,7 +677,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes) unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None: if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype) unet_weight_dtype.append(weight_dtype)
if dtype is None: if dtype is None:

View File

@ -49,7 +49,7 @@ class BASE:
manual_cast_dtype = None manual_cast_dtype = None
custom_operations = None custom_operations = None
scaled_fp8 = False scaled_fp8 = None
optimizations = {"fp8": False} optimizations = {"fp8": False}
@classmethod @classmethod