mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-19 22:17:08 +08:00
add args: npu-device
This commit is contained in:
parent
35504e2f93
commit
f4208a28fd
@ -50,6 +50,7 @@ parser.add_argument("--input-directory", type=str, default=None, help="Set the C
|
|||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||||
|
parser.add_argument("--npu-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the Ascend npu device this instance will use.")
|
||||||
cm_group = parser.add_mutually_exclusive_group()
|
cm_group = parser.add_mutually_exclusive_group()
|
||||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||||
|
@ -125,9 +125,20 @@ def is_mlu():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def set_npu_device(device_id):
|
||||||
|
"""Set the NPU device to use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_id (int): The id of the NPU device to use.
|
||||||
|
"""
|
||||||
|
if device_id is not None:
|
||||||
|
torch.npu.set_device(device_id)
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
|
if is_ascend_npu():
|
||||||
|
return torch.device("npu", torch.npu.current_device())
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
global directml_device
|
global directml_device
|
||||||
return directml_device
|
return directml_device
|
||||||
@ -138,8 +149,6 @@ def get_torch_device():
|
|||||||
else:
|
else:
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return torch.device("xpu", torch.xpu.current_device())
|
return torch.device("xpu", torch.xpu.current_device())
|
||||||
elif is_ascend_npu():
|
|
||||||
return torch.device("npu", torch.npu.current_device())
|
|
||||||
elif is_mlu():
|
elif is_mlu():
|
||||||
return torch.device("mlu", torch.mlu.current_device())
|
return torch.device("mlu", torch.mlu.current_device())
|
||||||
else:
|
else:
|
||||||
|
6
main.py
6
main.py
@ -6,6 +6,7 @@ import importlib.util
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import comfy.model_management
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
@ -114,6 +115,10 @@ if __name__ == "__main__":
|
|||||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||||
|
|
||||||
|
if args.npu_device is not None:
|
||||||
|
comfy.model_management.set_npu_device(args.npu_device)
|
||||||
|
logging.info("Set npu device to: {}".format(args.npu_device))
|
||||||
|
|
||||||
if args.oneapi_device_selector is not None:
|
if args.oneapi_device_selector is not None:
|
||||||
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
|
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
|
||||||
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
|
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
|
||||||
@ -137,7 +142,6 @@ import execution
|
|||||||
import server
|
import server
|
||||||
from server import BinaryEventTypes
|
from server import BinaryEventTypes
|
||||||
import nodes
|
import nodes
|
||||||
import comfy.model_management
|
|
||||||
import comfyui_version
|
import comfyui_version
|
||||||
import app.logger
|
import app.logger
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user