mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
You can now select the device index with: --directml id
Like this for example: --directml 1
This commit is contained in:
parent
cab80973d1
commit
2ca934f7d4
@ -10,7 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
parser.add_argument("--directml", action="store_true", help="Use torch-directml.")
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
|
||||
|
@ -21,10 +21,15 @@ accelerate_enabled = False
|
||||
xpu_available = False
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml:
|
||||
if args.directml is not None:
|
||||
import torch_directml
|
||||
print("Using directml")
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
if device_index < 0:
|
||||
directml_device = torch_directml.device()
|
||||
else:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
|
||||
try:
|
||||
@ -226,7 +231,8 @@ def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
return torch_directml.device()
|
||||
global directml_device
|
||||
return directml_device
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
|
Loading…
Reference in New Issue
Block a user