Add --fast argument to enable experimental optimizations.

Optimizations that might break things/lower quality will be put behind
this flag first and might be enabled by default in the future.

Currently the only optimization is float8_e4m3fn matrix multiplication on
4000/ADA series Nvidia cards or later. If you have one of these cards you
will see a speed boost when using fp8_e4m3fn flux for example.
This commit is contained in:
comfyanonymous 2024-08-20 11:49:33 -04:00
parent d1a6bd6845
commit 9953f22fce
4 changed files with 52 additions and 5 deletions

View File

@ -123,6 +123,7 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

View File

@ -96,10 +96,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
if self.manual_cast_dtype is not None: operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
else: else:
operations = model_config.custom_operations operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

View File

@ -1048,6 +1048,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False return False
def supports_fp8_compute(device=None):
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
def soft_empty_cache(force=False): def soft_empty_cache(force=False):
global cpu_state global cpu_state
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:

View File

@ -18,7 +18,7 @@
import torch import torch
import comfy.model_management import comfy.model_management
from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False): def cast_to(weight, dtype=None, device=None, non_blocking=False):
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device): if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
@ -242,3 +242,42 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding): class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True comfy_cast_weights = True
def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
if len(input.shape) == 3:
out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
inn = input.to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
for i in range(input.shape[0]):
if self.bias is not None:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
else:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
out[i] = o
return out
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast