diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 77009c91..c98d4dfa 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -123,6 +123,7 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") +parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") diff --git a/comfy/model_base.py b/comfy/model_base.py index 830bcc68..9bfdb3b3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -96,10 +96,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: - if self.manual_cast_dtype is not None: - operations = comfy.ops.manual_cast - else: - operations = comfy.ops.disable_weight_init + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) diff --git a/comfy/model_management.py b/comfy/model_management.py index 27a3c9d6..b899f8b8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1048,6 +1048,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False +def supports_fp8_compute(device=None): + props = torch.cuda.get_device_properties(device) + if props.major >= 9: + return True + if props.major < 8: + return False + if props.minor < 9: + return False + return True + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/ops.py b/comfy/ops.py index 418d59e1..fc78dd83 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -18,7 +18,7 @@ import torch import comfy.model_management - +from comfy.cli_args import args def cast_to(weight, dtype=None, device=None, non_blocking=False): if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device): @@ -242,3 +242,42 @@ class manual_cast(disable_weight_init): class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True + + +def fp8_linear(self, input): + dtype = self.weight.dtype + if dtype not in [torch.float8_e4m3fn]: + return None + + if len(input.shape) == 3: + out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype) + inn = input.to(dtype) + non_blocking = comfy.model_management.device_supports_non_blocking(input.device) + w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t() + for i in range(input.shape[0]): + if self.bias is not None: + o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking)) + else: + o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype) + out[i] = o + return out + return None + +class fp8_ops(manual_cast): + class Linear(manual_cast.Linear): + def forward_comfy_cast_weights(self, input): + out = fp8_linear(self, input) + if out is not None: + return out + + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + +def pick_operations(weight_dtype, compute_dtype, load_device=None): + if compute_dtype is None or weight_dtype == compute_dtype: + return disable_weight_init + if args.fast: + if comfy.model_management.supports_fp8_compute(load_device): + return fp8_ops + return manual_cast