diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 229fe499d..b98820d83 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -321,8 +321,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None) - if scaled_fp8_weight is not None: + scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict.pop(scaled_fp8_key) model_config.scaled_fp8 = scaled_fp8_weight.dtype if model_config.scaled_fp8 == torch.float32: model_config.scaled_fp8 = torch.float8_e4m3fn