diff --git a/comfy/model_base.py b/comfy/model_base.py index a304c58bd..2fa1ee911 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: - fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None) + fp8 = model_config.optimizations.get("fp8", False) operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) else: operations = model_config.custom_operations diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 1aef549f4..403da5855 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.scaled_fp8 = scaled_fp8_weight.dtype if model_config.scaled_fp8 == torch.float32: model_config.scaled_fp8 = torch.float8_e4m3fn + if scaled_fp8_weight.nelement() == 2: + model_config.optimizations["fp8"] = False + else: + model_config.optimizations["fp8"] = True return model_config diff --git a/comfy/ops.py b/comfy/ops.py index 358c6ec60..3303c6fcd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -17,6 +17,7 @@ """ import torch +import logging import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float @@ -308,6 +309,7 @@ class fp8_ops(manual_cast): return torch.nn.functional.linear(input, weight, bias) def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): + logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) class scaled_fp8_op(manual_cast): class Linear(manual_cast.Linear): def __init__(self, *args, **kwargs): @@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8) + return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=True, override_dtype=scaled_fp8) if ( fp8_compute and