You can now select the device index with: --directml id

Like this for example: --directml 1
This commit is contained in:
comfyanonymous 2023-04-28 16:51:35 -04:00
parent cab80973d1
commit 2ca934f7d4
2 changed files with 10 additions and 4 deletions

View File

@@ -10,7 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", action="store_true", help="Use torch-directml.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")

View File

@@ -21,10 +21,15 @@ accelerate_enabled = False
xpu_available = False
directml_enabled = False
if args.directml:
if args.directml is not None:
import torch_directml
print("Using directml")
directml_enabled = True
device_index = args.directml
if device_index < 0:
directml_device = torch_directml.device()
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
try:
@@ -226,7 +231,8 @@ def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
return torch_directml.device()
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU: