mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 12:23:30 +00:00
Guard CublasOps behind --fast arg
This commit is contained in:
parent
4faebf0fb2
commit
b151877371
@ -136,8 +136,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
|
|||||||
class PerformanceFeature(enum.Enum):
|
class PerformanceFeature(enum.Enum):
|
||||||
Fp16Accumulation = "fp16_accumulation"
|
Fp16Accumulation = "fp16_accumulation"
|
||||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||||
|
CublasOps = "cublas_ops"
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
|
@ -388,7 +388,12 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
|
|||||||
):
|
):
|
||||||
return fp8_ops
|
return fp8_ops
|
||||||
|
|
||||||
if CUBLAS_IS_AVAILABLE and weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None):
|
if (
|
||||||
|
PerformanceFeature.CublasOps in args.fast and
|
||||||
|
CUBLAS_IS_AVAILABLE and
|
||||||
|
weight_dtype == torch.float16 and
|
||||||
|
(compute_dtype == torch.float16 or compute_dtype is None)
|
||||||
|
):
|
||||||
logging.info("Using cublas ops")
|
logging.info("Using cublas ops")
|
||||||
return cublas_ops
|
return cublas_ops
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user