diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 79ecbd682..81f29f098 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -136,8 +136,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
 class PerformanceFeature(enum.Enum):
     Fp16Accumulation = "fp16_accumulation"
     Fp8MatrixMultiplication = "fp8_matrix_mult"
+    CublasOps = "cublas_ops"
 
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
 
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
diff --git a/comfy/ops.py b/comfy/ops.py
index 116b08fc2..9a5c1ee99 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -388,7 +388,12 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     ):
         return fp8_ops
 
-    if CUBLAS_IS_AVAILABLE and weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None):
+    if (
+        PerformanceFeature.CublasOps in args.fast and
+        CUBLAS_IS_AVAILABLE and
+        weight_dtype == torch.float16 and
+        (compute_dtype == torch.float16 or compute_dtype is None)
+    ):
         logging.info("Using cublas ops")
         return cublas_ops
 