From 2326ff1263632068dd4c98ddaec1a24418834c25 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 17 Feb 2023 21:14:07 -0500
Subject: [PATCH] Add: --highvram for when you want models to stay on the vram.

---
 comfy/model_management.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 1301f746..8c859d3f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -3,6 +3,7 @@ CPU = 0
 NO_VRAM = 1
 LOW_VRAM = 2
 NORMAL_VRAM = 3
+HIGH_VRAM = 4
 
 accelerate_enabled = False
 vram_state = NORMAL_VRAM
@@ -27,10 +28,11 @@ if "--lowvram" in sys.argv:
     set_vram_to = LOW_VRAM
 if "--novram" in sys.argv:
     set_vram_to = NO_VRAM
+if "--highvram" in sys.argv:
+    vram_state = HIGH_VRAM
 
 
-
-if set_vram_to != NORMAL_VRAM:
+if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
     try:
         import accelerate
         accelerate_enabled = True
@@ -44,7 +46,7 @@
         total_vram_available_mb = int(max(256, total_vram_available_mb))
 
 
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
 
 
 current_loaded_model = None
@@ -57,18 +59,24 @@ def unload_model():
     global current_loaded_model
     global model_accelerated
     global current_gpu_controlnets
+    global vram_state
+
     if current_loaded_model is not None:
         if model_accelerated:
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
             model_accelerated = False
 
-        current_loaded_model.model.cpu()
+        #never unload models from GPU on high vram
+        if vram_state != HIGH_VRAM:
+            current_loaded_model.model.cpu()
         current_loaded_model.unpatch_model()
         current_loaded_model = None
-    if len(current_gpu_controlnets) > 0:
-        for n in current_gpu_controlnets:
-            n.cpu()
-        current_gpu_controlnets = []
+
+    if vram_state != HIGH_VRAM:
+        if len(current_gpu_controlnets) > 0:
+            for n in current_gpu_controlnets:
+                n.cpu()
+            current_gpu_controlnets = []
 
 
 def load_model_gpu(model):
@@ -87,7 +95,7 @@ def load_model_gpu(model):
     current_loaded_model = model
     if vram_state == CPU:
         pass
-    elif vram_state == NORMAL_VRAM:
+    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
        model_accelerated = False
        real_model.cuda()
    else:
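
Note: the effect of the new HIGH_VRAM state is that unload_model() no longer moves weights back to system RAM, so repeated load_model_gpu() calls skip the host-to-device copy, at the cost of keeping the VRAM occupied between runs. Below is a minimal standalone sketch of that placement logic; the unload_weights helper and the toy nn.Linear are illustrative stand-ins, not the real ModelPatcher plumbing, and it assumes a CUDA device is available.

import torch
import torch.nn as nn

# VRAM state constants mirrored from comfy/model_management.py
NORMAL_VRAM = 3
HIGH_VRAM = 4

def unload_weights(model: nn.Module, vram_state: int) -> None:
    # On NORMAL_VRAM, unloading offloads the weights to system RAM;
    # on HIGH_VRAM the transfer is skipped, so the weights stay resident
    # on the GPU and the next load has nothing to copy.
    if vram_state != HIGH_VRAM:
        model.cpu()

model = nn.Linear(4, 4).cuda()
unload_weights(model, HIGH_VRAM)
print(next(model.parameters()).device)  # cuda:0 -- stayed on the GPU
unload_weights(model, NORMAL_VRAM)
print(next(model.parameters()).device)  # cpu -- offloaded as before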