From 0075c6d0968fd0a8ee8426f6cadbd40d287f15a0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 21 Oct 2024 18:12:51 -0400 Subject: [PATCH] Mixed precision diffusion models with scaled fp8. This change adds support for diffusion models where all the linears are scaled fp8 while the other weights stay in the original precision. --- comfy/model_base.py | 6 +++--- comfy/model_detection.py | 7 +++++-- comfy/ops.py | 6 +++--- comfy/sd.py | 4 ++-- comfy/supported_models_base.py | 2 +- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7e86e76d..5138d2b9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: - fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8) + fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None) operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) else: operations = model_config.custom_operations @@ -246,8 +246,8 @@ class BaseModel(torch.nn.Module): unet_state_dict = self.diffusion_model.state_dict() - if self.model_config.scaled_fp8: - unet_state_dict["scaled_fp8"] = torch.tensor([]) + if self.model_config.scaled_fp8 is not None: + unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3f720bce..e1d29db3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -288,8 +288,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - if 
"{}scaled_fp8".format(unet_key_prefix) in state_dict: - model_config.scaled_fp8 = True + scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None) + if scaled_fp8_weight is not None: + model_config.scaled_fp8 = scaled_fp8_weight.dtype + if model_config.scaled_fp8 == torch.float32: + model_config.scaled_fp8 = torch.float8_e4m3fn return model_config diff --git a/comfy/ops.py b/comfy/ops.py index 05f7d306..2890cac0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -334,10 +334,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None return scaled_fp8_op -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) - if scaled_fp8: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True) + if scaled_fp8 is not None: + return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8) if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8: return fp8_ops diff --git a/comfy/sd.py b/comfy/sd.py index bcec48c0..e4abf0b9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -579,7 +579,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None unet_weight_dtype = list(model_config.supported_inference_dtypes) - if weight_dtype is not None: + if weight_dtype is not None and model_config.scaled_fp8 is None: unet_weight_dtype.append(weight_dtype) model_config.custom_operations = model_options.get("custom_operations", None) @@ -677,7 +677,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if weight_dtype 
is not None: + if weight_dtype is not None and model_config.scaled_fp8 is None: unet_weight_dtype.append(weight_dtype) if dtype is None: diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 0e69d16a..54573abb 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -49,7 +49,7 @@ class BASE: manual_cast_dtype = None custom_operations = None - scaled_fp8 = False + scaled_fp8 = None optimizations = {"fp8": False} @classmethod