mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 3c6ff8821c
Merge branch 'master' of https://github.com/GaidamakUA/ComfyUI

The diff below (comfy/model_management.py) adds an MPS vram state so models can be placed on Apple-silicon GPUs through PyTorch's Metal backend.
@@ -4,6 +4,7 @@ NO_VRAM = 1
 LOW_VRAM = 2
 NORMAL_VRAM = 3
 HIGH_VRAM = 4
+MPS = 5
 
 accelerate_enabled = False
 vram_state = NORMAL_VRAM
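For readers skimming the diff: the vram states are plain module-level ints, and the hunk header plus the name list in the print() change further down imply CPU = 0 and NO_VRAM = 1 above the visible context. A minimal sketch of the pattern; the state_name helper is illustrative, not part of the repo:

CPU = 0        # implied by the name list below; not visible in this hunk
NO_VRAM = 1    # from the hunk header
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5        # the new state added by this commit

VRAM_STATE_NAMES = ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"]

def state_name(state):
    # Index the parallel name list by the integer state value.
    return VRAM_STATE_NAMES[state]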
@@ -61,7 +62,8 @@ if "--novram" in sys.argv:
     set_vram_to = NO_VRAM
 if "--highvram" in sys.argv:
     vram_state = HIGH_VRAM
 
+if torch.backends.mps.is_available():
+    vram_state = MPS
 
 if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
     try:
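Note on the detection above: torch.backends.mps.is_available() (available since PyTorch 1.12) reports whether the Metal backend is usable at runtime, and as merged it overrides any state chosen by the CLI flags. A self-contained sketch of the same probe, assuming only PyTorch 1.12+ and nothing else about this file:

import torch

def pick_device():
    # Prefer Apple's Metal backend when it is usable at runtime,
    # then CUDA, then plain CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

print(pick_device())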
@@ -79,7 +81,7 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
 if "--cpu" in sys.argv:
     vram_state = CPU
 
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
 
 
 current_loaded_model = None
@@ -128,6 +130,10 @@ def load_model_gpu(model):
     current_loaded_model = model
     if vram_state == CPU:
         pass
+    elif vram_state == MPS:
+        mps_device = torch.device("mps")
+        real_model.to(mps_device)
+        pass
     elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
         real_model.cuda()
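The new branch moves the model with Module.to() because .cuda() is CUDA-only; the trailing pass is a no-op that mirrors the shape of the CPU branch. A minimal standalone illustration of that placement (toy model, not the repo's real_model):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
if torch.backends.mps.is_available():
    # Equivalent of the real_model.to(mps_device) call in the diff:
    # parameters move into unified memory shared with the CPU.
    model = model.to(torch.device("mps"))

x = torch.randn(1, 4, device=model.weight.device)
print(model(x).device)  # mps:0 on Apple silicon, cpu otherwise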
@@ -146,6 +152,9 @@ def load_controlnet_gpu(models):
     global vram_state
     if vram_state == CPU:
         return
 
+    if vram_state == MPS:
+        return
+
     if vram_state == LOW_VRAM or vram_state == NO_VRAM:
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
@@ -173,6 +182,8 @@ def unload_if_low_vram(model):
     return model
 
 def get_torch_device():
+    if vram_state == MPS:
+        return torch.device("mps")
     if vram_state == CPU:
         return torch.device("cpu")
     else:
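Only the two early returns are visible; the else branch is cut off by the hunk. A sketch of the resolution order after this patch, with the CUDA branch filled in as an assumption: torch.cuda.current_device() returns an int index rather than a torch.device, which would explain the hasattr(dev, 'type') guard in get_free_memory below.

import torch

def get_torch_device_sketch(vram_state):
    if vram_state == MPS:
        return torch.device("mps")
    if vram_state == CPU:
        return torch.device("cpu")
    # Assumed CUDA path; not shown in this diff.
    return torch.cuda.current_device()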
@@ -195,7 +206,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and dev.type == 'cpu':
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
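For cpu and mps devices, "free memory" is simply available host RAM, which is sensible for MPS because Apple GPUs share unified memory with the CPU. The else branch (CUDA accounting) is truncated by the hunk; the sketch below substitutes torch.cuda.mem_get_info, a real PyTorch API but an assumption about this repo's code:

import psutil
import torch

def free_memory_sketch(dev):
    if hasattr(dev, 'type') and dev.type in ('cpu', 'mps'):
        # Unified/host memory: ask the OS how much RAM is still available.
        return psutil.virtual_memory().available
    # CUDA path (assumed): per-device free bytes as reported by the driver.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(dev)
    return free_bytes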
@@ -224,8 +235,12 @@ def cpu_mode():
     global vram_state
     return vram_state == CPU
 
+def mps_mode():
+    global vram_state
+    return vram_state == MPS
+
 def should_use_fp16():
-    if cpu_mode():
+    if cpu_mode() or mps_mode():
         return False #TODO ?
 
     if torch.cuda.is_bf16_supported():
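After this change fp16 is ruled out on both CPU and MPS; the "#TODO ?" comment suggests the authors were unsure the MPS exclusion is strictly necessary. The function continues past the bf16 check, but the hunk truncates it. A condensed sketch of the visible logic only:

import torch

def should_use_fp16_sketch(vram_state):
    # No fp16 on cpu or mps after this patch (marked "TODO ?" upstream).
    if vram_state in (CPU, MPS):
        return False
    # The visible CUDA-side check; the original function continues with
    # further hardware checks that this hunk cuts off.
    return torch.cuda.is_bf16_supported()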