Added MPS device support

This commit is contained in:
Yurii Mazurevich 2023-03-24 14:04:50 +02:00
parent dd095efc2c
commit 89fd5ed574

View File

@ -4,6 +4,7 @@ NO_VRAM = 1
LOW_VRAM = 2 LOW_VRAM = 2
NORMAL_VRAM = 3 NORMAL_VRAM = 3
HIGH_VRAM = 4 HIGH_VRAM = 4
# Must not share a value with HIGH_VRAM (4): states are compared with == and
# the value also indexes the state-name list, where "MPS" sits at position 5.
MPS = 5
accelerate_enabled = False accelerate_enabled = False
vram_state = NORMAL_VRAM vram_state = NORMAL_VRAM
@ -61,7 +62,8 @@ if "--novram" in sys.argv:
set_vram_to = NO_VRAM set_vram_to = NO_VRAM
if "--highvram" in sys.argv: if "--highvram" in sys.argv:
vram_state = HIGH_VRAM vram_state = HIGH_VRAM
if torch.backends.mps.is_available():
vram_state = MPS
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try: try:
@ -79,7 +81,7 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
if "--cpu" in sys.argv: if "--cpu" in sys.argv:
vram_state = CPU vram_state = CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state]) print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
current_loaded_model = None current_loaded_model = None
@ -128,6 +130,12 @@ def load_model_gpu(model):
current_loaded_model = model current_loaded_model = model
if vram_state == CPU: if vram_state == CPU:
pass pass
elif vram_state == MPS:
# print(inspect.getmro(real_model.__class__))
# print(dir(real_model))
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
model_accelerated = False model_accelerated = False
real_model.cuda() real_model.cuda()
@ -146,6 +154,9 @@ def load_controlnet_gpu(models):
global vram_state global vram_state
if vram_state == CPU: if vram_state == CPU:
return return
if vram_state == MPS:
return
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == LOW_VRAM or vram_state == NO_VRAM:
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
@ -173,6 +184,8 @@ def unload_if_low_vram(model):
return model return model
def get_torch_device(): def get_torch_device():
if vram_state == MPS:
return torch.device("mps")
if vram_state == CPU: if vram_state == CPU:
return torch.device("cpu") return torch.device("cpu")
else: else:
@ -195,7 +208,7 @@ def get_free_memory(dev=None, torch_free_too=False):
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
if hasattr(dev, 'type') and dev.type == 'cpu': if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
@ -224,8 +237,12 @@ def cpu_mode():
global vram_state global vram_state
return vram_state == CPU return vram_state == CPU
def mps_mode():
    """Report whether the module-level vram_state currently equals MPS."""
    # Read-only access to the module global, so no `global` declaration is needed.
    return vram_state == MPS
def should_use_fp16(): def should_use_fp16():
if cpu_mode(): if cpu_mode() or mps_mode():
return False #TODO ? return False #TODO ?
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():