Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-15 14:09:36 +00:00)
Load controlnet in fp8 if weights are in fp8.
This commit is contained in:
Parent: 2d810b081e
Commit: dc96a1ae19
@@ -400,19 +400,22 @@ class ControlLora(ControlNet):
|
|||||||
def controlnet_config(sd, model_options={}):
|
def controlnet_config(sd, model_options={}):
|
||||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||||
|
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
controlnet_config = model_config.unet_config
|
|
||||||
unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes))
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
|
||||||
operations = model_options.get("custom_operations", None)
|
operations = model_options.get("custom_operations", None)
|
||||||
if operations is None:
|
if operations is None:
|
||||||
if manual_cast_dtype is not None:
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
operations = comfy.ops.manual_cast
|
|
||||||
else:
|
|
||||||
operations = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
offload_device = comfy.model_management.unet_offload_device()
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||||
@@ -583,22 +586,30 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
controlnet_config = model_config.unet_config
|
controlnet_config = model_config.unet_config
|
||||||
|
|
||||||
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||||
|
|
||||||
|
if supported_inference_dtypes is None:
|
||||||
|
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||||
|
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
if supported_inference_dtypes is None:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
|
||||||
else:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
operations = model_options.get("custom_operations", None)
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
if operations is None:
|
||||||
if "custom_operations" in model_options:
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||||
controlnet_config["operations"] = model_options["custom_operations"]
|
|
||||||
if "dtype" in model_options:
|
controlnet_config["operations"] = operations
|
||||||
controlnet_config["dtype"] = model_options["dtype"]
|
controlnet_config["dtype"] = unet_dtype
|
||||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
|
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
|
|||||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||||
|
|
||||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||||
|
if model_params < 0:
|
||||||
|
model_params = 1000000000000000000000
|
||||||
if args.bf16_unet:
|
if args.bf16_unet:
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if args.fp16_unet:
|
if args.fp16_unet:
|
||||||
|
@@ -300,10 +300,10 @@ class fp8_ops(manual_cast):
|
|||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
if args.fast:
|
if args.fast and not disable_fast_fp8:
|
||||||
if comfy.model_management.supports_fp8_compute(load_device):
|
if comfy.model_management.supports_fp8_compute(load_device):
|
||||||
return fp8_ops
|
return fp8_ops
|
||||||
return manual_cast
|
return manual_cast
|
||||||
|
Loading…
Reference in New Issue
Block a user