From 4faebf0fb2408a64d2f38032a5d0b9babb81d1cb Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Fri, 11 Apr 2025 18:20:15 -0400
Subject: [PATCH 1/2] CublasOps support

---
 comfy/ops.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/comfy/ops.py b/comfy/ops.py
index ced461011..116b08fc2 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -357,6 +357,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
 
     return scaled_fp8_op
 
+CUBLAS_IS_AVAILABLE = False
+try:
+    from cublas_ops import CublasLinear
+    CUBLAS_IS_AVAILABLE = True
+except ImportError:
+    pass
+
+if CUBLAS_IS_AVAILABLE:
+    class cublas_ops(disable_weight_init):
+        class Linear(CublasLinear, disable_weight_init.Linear):
+            def reset_parameters(self):
+                return None
+
+            def forward_comfy_cast_weights(self, input):
+                return super().forward(input)
+
+            def forward(self, *args, **kwargs):
+                return super().forward(*args, **kwargs)
+
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
@@ -369,6 +388,10 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     ):
         return fp8_ops
 
+    if CUBLAS_IS_AVAILABLE and weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None):
+        logging.info("Using cublas ops")
+        return cublas_ops
+
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init

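The nested Linear in this patch relies on Python's method resolution order: subclassing both CublasLinear and disable_weight_init.Linear means super().forward(...) resolves to CublasLinear's cuBLAS-backed fp16 GEMM, while the no-op reset_parameters keeps ComfyUI's convention of skipping random weight initialization (weights are loaded from the checkpoint afterwards). A minimal self-contained sketch of the same pattern, using an illustrative FastLinear stand-in so it runs without the optional cublas_ops package:

    import torch
    import torch.nn as nn

    # Stand-in for CublasLinear (illustrative name only): a Linear subclass
    # with its own forward path. The real class dispatches to a cuBLAS
    # fp16 GEMM here.
    class FastLinear(nn.Linear):
        def forward(self, input):
            return super().forward(input)

    # Mirrors disable_weight_init.Linear: skip random init, since weights
    # are expected to be loaded from a state dict later.
    class NoInitLinear(nn.Linear):
        def reset_parameters(self):
            return None

    # Same shape as the patch's cublas_ops.Linear.
    class CombinedLinear(FastLinear, NoInitLinear):
        def reset_parameters(self):
            return None

    # MRO: CombinedLinear -> FastLinear -> NoInitLinear -> nn.Linear, so
    # forward() takes the fast path while reset_parameters() stays a no-op.
    layer = CombinedLinear(8, 8)
    layer.weight.data.zero_()  # torch.empty leaves garbage; zero for the demo
    layer.bias.data.zero_()
    print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
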
From b15187737198407f542fe5c9088a557c733b7759 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Fri, 11 Apr 2025 18:36:50 -0400
Subject: [PATCH 2/2] Guard CublasOps behind --fast arg

---
 comfy/cli_args.py | 3 ++-
 comfy/ops.py      | 7 ++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 79ecbd682..81f29f098 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -136,8 +136,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
 class PerformanceFeature(enum.Enum):
     Fp16Accumulation = "fp16_accumulation"
     Fp8MatrixMultiplication = "fp8_matrix_mult"
+    CublasOps = "cublas_ops"
 
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
 
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

diff --git a/comfy/ops.py b/comfy/ops.py
index 116b08fc2..9a5c1ee99 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -388,7 +388,12 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     ):
         return fp8_ops
 
-    if CUBLAS_IS_AVAILABLE and weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None):
+    if (
+        PerformanceFeature.CublasOps in args.fast and
+        CUBLAS_IS_AVAILABLE and
+        weight_dtype == torch.float16 and
+        (compute_dtype == torch.float16 or compute_dtype is None)
+    ):
         logging.info("Using cublas ops")
         return cublas_ops
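With the second patch applied, the fp16 cuBLAS path is strictly opt-in. Assuming a package providing the cublas_ops module is installed (so the import probe sets CUBLAS_IS_AVAILABLE) and the model weights are fp16, it is enabled through the existing --fast flag; a usage sketch, assuming ComfyUI's standard main.py entry point:

    python main.py --fast cublas_ops   # enable only this optimization
    python main.py --fast              # no arguments: enable every PerformanceFeature

If the module is missing or the flag is not passed, pick_operations falls through to the pre-existing fp8/fp16 branches unchanged.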