diff --git a/comfy/ops.py b/comfy/ops.py
index ced461011..116b08fc2 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -357,6 +357,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
 
     return scaled_fp8_op
 
+CUBLAS_IS_AVAILABLE = False
+try:
+    from cublas_ops import CublasLinear
+    CUBLAS_IS_AVAILABLE = True
+except ImportError:
+    pass
+
+if CUBLAS_IS_AVAILABLE:
+    class cublas_ops(disable_weight_init):
+        class Linear(CublasLinear, disable_weight_init.Linear):
+            def reset_parameters(self):
+                return None
+
+            def forward_comfy_cast_weights(self, input):
+                return super().forward(input)
+
+            def forward(self, *args, **kwargs):
+                return super().forward(*args, **kwargs)
+
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
@@ -369,6 +388,10 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     ):
         return fp8_ops
 
+    if CUBLAS_IS_AVAILABLE and weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None):
+        logging.info("Using cublas ops")
+        return cublas_ops
+
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
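
Note for reviewers: the nested `Linear(CublasLinear, disable_weight_init.Linear)` class leans on Python's MRO so that `super().forward(...)` in both `forward` and `forward_comfy_cast_weights` dispatches to `CublasLinear.forward` instead of the stock dispatch in `disable_weight_init.Linear`. Below is a minimal, self-contained sketch of that dispatch using stand-in classes (`FakeCublasLinear` and `ComfyLinear` are hypothetical placeholders, not the real `cublas_ops.CublasLinear` or ComfyUI ops), so it runs without either package installed:

```python
import torch
import torch.nn as nn

class FakeCublasLinear(nn.Linear):
    # Stand-in for cublas_ops.CublasLinear: in the real package this is
    # where the cuBLAS-backed fp16 matmul path would live.
    def forward(self, input):
        print("dispatched to FakeCublasLinear.forward")  # proves MRO dispatch
        return super().forward(input)

class ComfyLinear(nn.Linear):
    # Stand-in for disable_weight_init.Linear: skips weight init because
    # weights are expected to be loaded from a checkpoint afterwards.
    def reset_parameters(self):
        return None

class Linear(FakeCublasLinear, ComfyLinear):
    # MRO: Linear -> FakeCublasLinear -> ComfyLinear -> nn.Linear -> nn.Module
    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)

layer = Linear(4, 4)
layer.weight.data.normal_()  # "load" weights manually; init was skipped
layer.bias.data.zero_()
out = layer(torch.randn(2, 4))  # prints the dispatch message
print(out.shape)                # torch.Size([2, 4])
```

The explicit `forward` override matters: without it, `disable_weight_init.Linear.forward` would be found first in the MRO and route through its own cast-based dispatch; with it, both the plain and `forward_comfy_cast_weights` paths land in `CublasLinear.forward`. The `pick_operations` branch then selects this class only for fp16 weights with an fp16 (or unspecified) compute dtype; every other combination falls through to the existing ops.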