mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 12:23:30 +00:00
Merge b151877371
into 22ad513c72
This commit is contained in:
commit
c516de4a2e
@ -136,8 +136,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
|
|||||||
class PerformanceFeature(enum.Enum):
    """Opt-in optimizations that can be selected with the ``--fast`` flag.

    Each member's value is the token accepted on the command line; the
    ``--fast`` argument uses this class as its ``type``, so a token is turned
    into a member by calling ``PerformanceFeature(token)``.
    """

    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    # Checked by pick_operations in comfy/ops.py to select the cublas-backed ops.
    CublasOps = "cublas_ops"
|
||||||
|
|
||||||
# --fast: opt into experimental optimizations.
#   * flag absent            -> argparse default for the option
#   * "--fast" alone         -> empty list (documented as "enable everything")
#   * "--fast a b"           -> list of the named PerformanceFeature members
# The list of valid names is generated from the enum so the help text cannot
# fall out of sync when a new feature is added.
parser.add_argument(
    "--fast",
    nargs="*",
    type=PerformanceFeature,
    help=(
        "Enable some untested and potentially quality deteriorating optimizations. "
        "--fast with no arguments enables everything. "
        "You can pass a list of specific optimizations if you only want to enable specific ones. "
        "Current valid optimizations: "
        + " ".join(feature.value for feature in PerformanceFeature)
    ),
)
|
||||||
|
|
||||||
# Boolean switches; both are off unless passed on the command line.
parser.add_argument(
    "--dont-print-server",
    action="store_true",
    help="Don't print server output.",
)
parser.add_argument(
    "--quick-test-for-ci",
    action="store_true",
    help="Quick test for CI.",
)
|
||||||
|
28
comfy/ops.py
28
comfy/ops.py
@ -357,6 +357,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
|
|
||||||
return scaled_fp8_op
|
return scaled_fp8_op
|
||||||
|
|
||||||
|
# cublas_ops is an optional dependency: record whether it could be imported so
# the rest of the module can fall back to the regular ops when it is missing.
try:
    from cublas_ops import CublasLinear
except ImportError:
    CUBLAS_IS_AVAILABLE = False
else:
    CUBLAS_IS_AVAILABLE = True
|
||||||
|
|
||||||
|
if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        """Operations set whose ``Linear`` layer derives from ``CublasLinear``.

        Only defined when the optional ``cublas_ops`` package was imported
        successfully; every other layer type is inherited unchanged from
        ``disable_weight_init``.
        """

        class Linear(CublasLinear, disable_weight_init.Linear):
            """Linear layer combining ``CublasLinear`` with the comfy Linear op.

            Both forward entry points delegate to the next ``forward`` in the
            MRO (presumably ``CublasLinear.forward`` -- confirm against the
            cublas_ops package).
            """

            def reset_parameters(self):
                """Do nothing and return None, so no weight initialization runs here."""

            def forward(self, *args, **kwargs):
                """Pass all arguments straight through to the parent forward."""
                return super().forward(*args, **kwargs)

            def forward_comfy_cast_weights(self, input):
                """Run the parent forward on ``input``; no cast is performed in this method."""
                return super().forward(input)
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
if scaled_fp8 is not None:
|
if scaled_fp8 is not None:
|
||||||
@ -369,6 +388,15 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
|
|||||||
):
|
):
|
||||||
return fp8_ops
|
return fp8_ops
|
||||||
|
|
||||||
|
if (
|
||||||
|
PerformanceFeature.CublasOps in args.fast and
|
||||||
|
CUBLAS_IS_AVAILABLE and
|
||||||
|
weight_dtype == torch.float16 and
|
||||||
|
(compute_dtype == torch.float16 or compute_dtype is None)
|
||||||
|
):
|
||||||
|
logging.info("Using cublas ops")
|
||||||
|
return cublas_ops
|
||||||
|
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user