mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 15:03:33 +00:00
CublasOps support
This commit is contained in:
parent
22ad513c72
commit
4faebf0fb2
23
comfy/ops.py
23
comfy/ops.py
@ -357,6 +357,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
|
|
||||||
return scaled_fp8_op
|
return scaled_fp8_op
|
||||||
|
|
||||||
|
CUBLAS_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
from cublas_ops import CublasLinear
|
||||||
|
CUBLAS_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if CUBLAS_IS_AVAILABLE:
|
||||||
|
class cublas_ops(disable_weight_init):
|
||||||
|
class Linear(CublasLinear, disable_weight_init.Linear):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
return super().forward(input)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
if scaled_fp8 is not None:
|
if scaled_fp8 is not None:
|
||||||
@ -369,6 +388,10 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
|
|||||||
):
|
):
|
||||||
return fp8_ops
|
return fp8_ops
|
||||||
|
|
||||||
|
if CUBLAS_IS_AVAILABLE and weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None):
|
||||||
|
logging.info("Using cublas ops")
|
||||||
|
return cublas_ops
|
||||||
|
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user