mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add --fast argument to enable experimental optimizations.
Optimizations that might break things/lower quality will be put behind this flag first and might be enabled by default in the future. Currently the only optimization is float8_e4m3fn matrix multiplication on 4000/ADA series Nvidia cards or later. If you have one of these cards you will see a speed boost when using fp8_e4m3fn flux for example.
This commit is contained in:
parent
d1a6bd6845
commit
9953f22fce
@ -123,6 +123,7 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
|
|||||||
|
|
||||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||||
|
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
|
@ -96,10 +96,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
if self.manual_cast_dtype is not None:
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
||||||
operations = comfy.ops.manual_cast
|
|
||||||
else:
|
|
||||||
operations = comfy.ops.disable_weight_init
|
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
@ -1048,6 +1048,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_fp8_compute(device=None):
|
||||||
|
props = torch.cuda.get_device_properties(device)
|
||||||
|
if props.major >= 9:
|
||||||
|
return True
|
||||||
|
if props.major < 8:
|
||||||
|
return False
|
||||||
|
if props.minor < 9:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
|
41
comfy/ops.py
41
comfy/ops.py
@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
||||||
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
|
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
|
||||||
@ -242,3 +242,42 @@ class manual_cast(disable_weight_init):
|
|||||||
|
|
||||||
class Embedding(disable_weight_init.Embedding):
|
class Embedding(disable_weight_init.Embedding):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_linear(self, input):
|
||||||
|
dtype = self.weight.dtype
|
||||||
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(input.shape) == 3:
|
||||||
|
out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
|
||||||
|
inn = input.to(dtype)
|
||||||
|
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||||
|
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
|
||||||
|
for i in range(input.shape[0]):
|
||||||
|
if self.bias is not None:
|
||||||
|
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
|
||||||
|
else:
|
||||||
|
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
|
||||||
|
out[i] = o
|
||||||
|
return out
|
||||||
|
return None
|
||||||
|
|
||||||
|
class fp8_ops(manual_cast):
|
||||||
|
class Linear(manual_cast.Linear):
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
out = fp8_linear(self, input)
|
||||||
|
if out is not None:
|
||||||
|
return out
|
||||||
|
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
||||||
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
|
return disable_weight_init
|
||||||
|
if args.fast:
|
||||||
|
if comfy.model_management.supports_fp8_compute(load_device):
|
||||||
|
return fp8_ops
|
||||||
|
return manual_cast
|
||||||
|
Loading…
Reference in New Issue
Block a user