diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index ff4385b33..9dfd69977 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -400,19 +400,22 @@ class ControlLora(ControlNet):
 def controlnet_config(sd, model_options={}):
     model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
 
-    supported_inference_dtypes = model_config.supported_inference_dtypes
+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = comfy.utils.weight_dtype(sd)
+
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
 
-    controlnet_config = model_config.unet_config
-    unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes))
     load_device = comfy.model_management.get_torch_device()
     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
 
     operations = model_options.get("custom_operations", None)
     if operations is None:
-        if manual_cast_dtype is not None:
-            operations = comfy.ops.manual_cast
-        else:
-            operations = comfy.ops.disable_weight_init
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
 
     offload_device = comfy.model_management.unet_offload_device()
     return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
@@ -583,22 +586,30 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
 
     if controlnet_config is None:
         model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
-        supported_inference_dtypes = model_config.supported_inference_dtypes
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
         controlnet_config = model_config.unet_config
 
+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = comfy.utils.weight_dtype(controlnet_data)
+
+        if supported_inference_dtypes is None:
+            supported_inference_dtypes = [comfy.model_management.unet_dtype()]
+
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+
     load_device = comfy.model_management.get_torch_device()
-    if supported_inference_dtypes is None:
-        unet_dtype = comfy.model_management.unet_dtype()
-    else:
-        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
 
     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
-    if manual_cast_dtype is not None:
-        controlnet_config["operations"] = comfy.ops.manual_cast
-    if "custom_operations" in model_options:
-        controlnet_config["operations"] = model_options["custom_operations"]
-    if "dtype" in model_options:
-        controlnet_config["dtype"] = model_options["dtype"]
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
+
+    controlnet_config["operations"] = operations
+    controlnet_config["dtype"] = unet_dtype
     controlnet_config["device"] = comfy.model_management.unet_offload_device()
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 22a584a2e..a97d489d5 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
     return (get_total_memory(device) * 0.88 - minimum_inference_memory())
 
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    if model_params < 0:
+        model_params = 1000000000000000000000
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
diff --git a/comfy/ops.py b/comfy/ops.py
index 43ed55adb..1b386dba7 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -300,10 +300,10 @@ class fp8_ops(manual_cast):
         return torch.nn.functional.linear(input, weight, bias)
 
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
-    if args.fast:
+    if args.fast and not disable_fast_fp8:
         if comfy.model_management.supports_fp8_compute(load_device):
             return fp8_ops
     return manual_cast