mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
separates out arg parser and imports args
This commit is contained in:
parent
dd29966f8a
commit
e5e587b1c0
29
comfy/cli_args.py
Normal file
29
comfy/cli_args.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""Command-line options for ComfyUI.

The parser is built and evaluated when this module is imported, so other
modules obtain the parsed options with ``from comfy.cli_args import args``.
"""
import argparse

parser = argparse.ArgumentParser()

# Network / filesystem options.
# ``nargs="?"`` with ``const`` lets a bare ``--listen`` (no address) bind to all
# interfaces, which is what the help text promises. Without it argparse
# rejects ``--listen`` unless an address follows. Omitting the flag entirely
# still yields the loopback default.
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is given without an address, listen on 0.0.0.0 so the UI can be accessed from other computers.")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")

# Cross-attention implementation: at most one of these flags may be passed.
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")

# Memory strategy: at most one of these flags may be passed.
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

# Miscellaneous switches.
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.")

# Parsed once at import time; importers share this single namespace object.
args = parser.parse_args()
|
@ -21,6 +21,8 @@ if model_management.xformers_enabled():
|
|||||||
import os
|
import os
|
||||||
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||||
|
|
||||||
|
from cli_args import args
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):
|
|||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
import sys
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
print("Using xformers cross attention")
|
print("Using xformers cross attention")
|
||||||
CrossAttention = MemoryEfficientCrossAttention
|
CrossAttention = MemoryEfficientCrossAttention
|
||||||
@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
|
|||||||
print("Using pytorch cross attention")
|
print("Using pytorch cross attention")
|
||||||
CrossAttention = CrossAttentionPytorch
|
CrossAttention = CrossAttentionPytorch
|
||||||
else:
|
else:
|
||||||
if "--use-split-cross-attention" in sys.argv:
|
if args.use_split_cross_attention:
|
||||||
print("Using split optimization for cross attention")
|
print("Using split optimization for cross attention")
|
||||||
CrossAttention = CrossAttentionDoggettx
|
CrossAttention = CrossAttentionDoggettx
|
||||||
else:
|
else:
|
||||||
|
@ -1,36 +1,35 @@
|
|||||||
|
import psutil
|
||||||
|
from enum import Enum
|
||||||
|
from cli_args import args
|
||||||
|
|
||||||
CPU = 0
|
class VRAMState(Enum):
|
||||||
NO_VRAM = 1
|
CPU = 0
|
||||||
LOW_VRAM = 2
|
NO_VRAM = 1
|
||||||
NORMAL_VRAM = 3
|
LOW_VRAM = 2
|
||||||
HIGH_VRAM = 4
|
NORMAL_VRAM = 3
|
||||||
MPS = 5
|
HIGH_VRAM = 4
|
||||||
|
MPS = 5
|
||||||
|
|
||||||
accelerate_enabled = False
|
# Determine VRAM State
|
||||||
vram_state = NORMAL_VRAM
|
vram_state = VRAMState.NORMAL_VRAM
|
||||||
|
set_vram_to = VRAMState.NORMAL_VRAM
|
||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
total_vram_available_mb = -1
|
total_vram_available_mb = -1
|
||||||
|
|
||||||
import sys
|
accelerate_enabled = False
|
||||||
import psutil
|
|
||||||
|
|
||||||
forced_cpu = "--cpu" in sys.argv
|
|
||||||
|
|
||||||
set_vram_to = NORMAL_VRAM
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
forced_normal_vram = "--normalvram" in sys.argv
|
if not args.normalvram and not args.cpu:
|
||||||
if not forced_normal_vram and not forced_cpu:
|
|
||||||
if total_vram <= 4096:
|
if total_vram <= 4096:
|
||||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||||
set_vram_to = LOW_VRAM
|
set_vram_to = VRAMState.LOW_VRAM
|
||||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||||
vram_state = HIGH_VRAM
|
vram_state = VRAMState.HIGH_VRAM
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -39,34 +38,32 @@ try:
|
|||||||
except:
|
except:
|
||||||
OOM_EXCEPTION = Exception
|
OOM_EXCEPTION = Exception
|
||||||
|
|
||||||
if "--disable-xformers" in sys.argv:
|
if args.disable_xformers:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
XFORMERS_IS_AVAILABLE = False
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
XFORMERS_IS_AVAILBLE = True
|
XFORMERS_IS_AVAILABLE = True
|
||||||
except:
|
except:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
XFORMERS_IS_AVAILABLE = False
|
||||||
|
|
||||||
ENABLE_PYTORCH_ATTENTION = False
|
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
|
||||||
if "--use-pytorch-cross-attention" in sys.argv:
|
if ENABLE_PYTORCH_ATTENTION:
|
||||||
torch.backends.cuda.enable_math_sdp(True)
|
torch.backends.cuda.enable_math_sdp(True)
|
||||||
torch.backends.cuda.enable_flash_sdp(True)
|
torch.backends.cuda.enable_flash_sdp(True)
|
||||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
XFORMERS_IS_AVAILABLE = False
|
||||||
XFORMERS_IS_AVAILBLE = False
|
|
||||||
|
if args.lowvram:
|
||||||
|
set_vram_to = VRAMState.LOW_VRAM
|
||||||
|
elif args.novram:
|
||||||
|
set_vram_to = VRAMState.NO_VRAM
|
||||||
|
elif args.highvram:
|
||||||
|
vram_state = VRAMState.HIGH_VRAM
|
||||||
|
|
||||||
|
|
||||||
if "--lowvram" in sys.argv:
|
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||||
set_vram_to = LOW_VRAM
|
|
||||||
if "--novram" in sys.argv:
|
|
||||||
set_vram_to = NO_VRAM
|
|
||||||
if "--highvram" in sys.argv:
|
|
||||||
vram_state = HIGH_VRAM
|
|
||||||
|
|
||||||
|
|
||||||
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
|
|
||||||
try:
|
try:
|
||||||
import accelerate
|
import accelerate
|
||||||
accelerate_enabled = True
|
accelerate_enabled = True
|
||||||
@ -81,14 +78,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
vram_state = MPS
|
vram_state = VRAMState.MPS
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if forced_cpu:
|
if args.cpu:
|
||||||
vram_state = CPU
|
vram_state = VRAMState.CPU
|
||||||
|
|
||||||
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
|
print(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
@ -109,12 +106,12 @@ def unload_model():
|
|||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
|
|
||||||
#never unload models from GPU on high vram
|
#never unload models from GPU on high vram
|
||||||
if vram_state != HIGH_VRAM:
|
if vram_state != VRAMState.HIGH_VRAM:
|
||||||
current_loaded_model.model.cpu()
|
current_loaded_model.model.cpu()
|
||||||
current_loaded_model.unpatch_model()
|
current_loaded_model.unpatch_model()
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
|
|
||||||
if vram_state != HIGH_VRAM:
|
if vram_state != VRAMState.HIGH_VRAM:
|
||||||
if len(current_gpu_controlnets) > 0:
|
if len(current_gpu_controlnets) > 0:
|
||||||
for n in current_gpu_controlnets:
|
for n in current_gpu_controlnets:
|
||||||
n.cpu()
|
n.cpu()
|
||||||
@ -135,19 +132,19 @@ def load_model_gpu(model):
|
|||||||
model.unpatch_model()
|
model.unpatch_model()
|
||||||
raise e
|
raise e
|
||||||
current_loaded_model = model
|
current_loaded_model = model
|
||||||
if vram_state == CPU:
|
if vram_state == VRAMState.CPU:
|
||||||
pass
|
pass
|
||||||
elif vram_state == MPS:
|
elif vram_state == VRAMState.MPS:
|
||||||
mps_device = torch.device("mps")
|
mps_device = torch.device("mps")
|
||||||
real_model.to(mps_device)
|
real_model.to(mps_device)
|
||||||
pass
|
pass
|
||||||
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
|
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
real_model.cuda()
|
real_model.cuda()
|
||||||
else:
|
else:
|
||||||
if vram_state == NO_VRAM:
|
if vram_state == VRAMState.NO_VRAM:
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||||
elif vram_state == LOW_VRAM:
|
elif vram_state == VRAMState.LOW_VRAM:
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
||||||
|
|
||||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
|
||||||
@ -157,10 +154,10 @@ def load_model_gpu(model):
|
|||||||
def load_controlnet_gpu(models):
|
def load_controlnet_gpu(models):
|
||||||
global current_gpu_controlnets
|
global current_gpu_controlnets
|
||||||
global vram_state
|
global vram_state
|
||||||
if vram_state == CPU:
|
if vram_state == VRAMState.CPU:
|
||||||
return
|
return
|
||||||
|
|
||||||
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
|
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||||
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -176,20 +173,20 @@ def load_controlnet_gpu(models):
|
|||||||
|
|
||||||
def load_if_low_vram(model):
|
def load_if_low_vram(model):
|
||||||
global vram_state
|
global vram_state
|
||||||
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
|
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||||
return model.cuda()
|
return model.cuda()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def unload_if_low_vram(model):
|
def unload_if_low_vram(model):
|
||||||
global vram_state
|
global vram_state
|
||||||
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
|
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||||
return model.cpu()
|
return model.cpu()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
if vram_state == MPS:
|
if vram_state == VRAMState.MPS:
|
||||||
return torch.device("mps")
|
return torch.device("mps")
|
||||||
if vram_state == CPU:
|
if vram_state == VRAMState.CPU:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
return torch.cuda.current_device()
|
return torch.cuda.current_device()
|
||||||
@ -201,9 +198,9 @@ def get_autocast_device(dev):
|
|||||||
|
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
if vram_state == CPU:
|
if vram_state == VRAMState.CPU:
|
||||||
return False
|
return False
|
||||||
return XFORMERS_IS_AVAILBLE
|
return XFORMERS_IS_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
def xformers_enabled_vae():
|
def xformers_enabled_vae():
|
||||||
@ -243,7 +240,7 @@ def get_free_memory(dev=None, torch_free_too=False):
|
|||||||
|
|
||||||
def maximum_batch_area():
|
def maximum_batch_area():
|
||||||
global vram_state
|
global vram_state
|
||||||
if vram_state == NO_VRAM:
|
if vram_state == VRAMState.NO_VRAM:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
memory_free = get_free_memory() / (1024 * 1024)
|
memory_free = get_free_memory() / (1024 * 1024)
|
||||||
@ -252,11 +249,11 @@ def maximum_batch_area():
|
|||||||
|
|
||||||
def cpu_mode():
|
def cpu_mode():
|
||||||
global vram_state
|
global vram_state
|
||||||
return vram_state == CPU
|
return vram_state == VRAMState.CPU
|
||||||
|
|
||||||
def mps_mode():
|
def mps_mode():
|
||||||
global vram_state
|
global vram_state
|
||||||
return vram_state == MPS
|
return vram_state == VRAMState.MPS
|
||||||
|
|
||||||
def should_use_fp16():
|
def should_use_fp16():
|
||||||
if cpu_mode() or mps_mode():
|
if cpu_mode() or mps_mode():
|
||||||
|
27
main.py
27
main.py
@ -1,37 +1,14 @@
|
|||||||
import argparse
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Script Arguments")
|
|
||||||
|
|
||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 0.0.0.0 if none given so the UI can be accessed from other computers.")
|
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
|
||||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.")
|
|
||||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
|
||||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
|
||||||
parser.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
|
|
||||||
parser.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")
|
|
||||||
parser.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
|
||||||
parser.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
|
||||||
parser.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
|
||||||
parser.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
|
||||||
parser.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
|
||||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.dont_upcast_attention:
|
if args.dont_upcast_attention:
|
||||||
print("disabling upcasting of attention")
|
print("disabling upcasting of attention")
|
||||||
os.environ['ATTN_PRECISION'] = "fp16"
|
os.environ['ATTN_PRECISION'] = "fp16"
|
||||||
@ -121,7 +98,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if args.output_directory:
|
if args.output_directory:
|
||||||
output_dir = os.path.abspath(args.output_directory)
|
output_dir = os.path.abspath(args.output_directory)
|
||||||
print("setting output directory to:", output_dir)
|
print(f"Setting output directory to: {output_dir}")
|
||||||
folder_paths.set_output_directory(output_dir)
|
folder_paths.set_output_directory(output_dir)
|
||||||
|
|
||||||
port = args.port
|
port = args.port
|
||||||
|
Loading…
Reference in New Issue
Block a user