mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 00:23:30 +00:00
Compare commits
7 Commits
261e5b1ef3
...
b3f0019ddb
Author | SHA1 | Date | |
---|---|---|---|
![]() |
b3f0019ddb | ||
![]() |
22ad513c72 | ||
![]() |
ed945a1790 | ||
![]() |
f9207c6936 | ||
![]() |
8ad7477647 | ||
![]() |
af6ce81f00 | ||
![]() |
b83869eb3c |
2
codebeaver.yml
Normal file
2
codebeaver.yml
Normal file
@ -0,0 +1,2 @@
|
||||
from:pytest
|
||||
# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration
|
@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
|
@ -316,3 +316,156 @@ class LRUCache(BasicCache):
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
|
||||
class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
A cache implementation that tracks dependencies between nodes and manages
|
||||
their execution and caching accordingly. It extends the BasicCache class.
|
||||
Nodes are removed from this cache once all of their descendants have been
|
||||
executed.
|
||||
"""
|
||||
|
||||
def __init__(self, key_class):
|
||||
"""
|
||||
Initialize the DependencyAwareCache.
|
||||
|
||||
Args:
|
||||
key_class: The class used for generating cache keys.
|
||||
"""
|
||||
super().__init__(key_class)
|
||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to initialize the cache for.
|
||||
is_changed_cache: Flag indicating if the cache has changed.
|
||||
"""
|
||||
# Clear all existing cache data
|
||||
self.cache.clear()
|
||||
self.subcaches.clear()
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
|
||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||
"""
|
||||
Build the dependency graph for all nodes.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to build the graph for.
|
||||
"""
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
for node_id in node_ids:
|
||||
self.descendants[node_id] = set()
|
||||
self.ancestors[node_id] = set()
|
||||
|
||||
for node_id in node_ids:
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
for input_data in inputs.values():
|
||||
if is_link(input_data): # Check if the input is a link to another node
|
||||
ancestor_id = input_data[0]
|
||||
self.descendants[ancestor_id].add(node_id)
|
||||
self.ancestors[node_id].add(ancestor_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
"""
|
||||
Mark a node as executed and store its value in the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to store.
|
||||
value: The value to store for the node.
|
||||
"""
|
||||
self._set_immediate(node_id, value)
|
||||
self.executed_nodes.add(node_id)
|
||||
self._cleanup_ancestors(node_id)
|
||||
|
||||
def get(self, node_id):
|
||||
"""
|
||||
Retrieve the cached value for a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to retrieve.
|
||||
|
||||
Returns:
|
||||
The cached value for the node.
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the parent node.
|
||||
children_ids: List of child node IDs to associate with the parent node.
|
||||
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
return subcache
|
||||
|
||||
def _cleanup_ancestors(self, node_id):
|
||||
"""
|
||||
Check if ancestors of a node can be removed from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose ancestors are to be checked.
|
||||
"""
|
||||
for ancestor_id in self.ancestors.get(node_id, []):
|
||||
if ancestor_id in self.executed_nodes:
|
||||
# Remove ancestor if all its descendants have been executed
|
||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||
self._remove_node(ancestor_id)
|
||||
|
||||
def _remove_node(self, node_id):
|
||||
"""
|
||||
Remove a node from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to remove.
|
||||
"""
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
del self.cache[cache_key]
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
del self.subcaches[subcache_key]
|
||||
|
||||
def clean_unused(self):
|
||||
"""
|
||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
"""
|
||||
Dump the cache and dependency graph for debugging.
|
||||
|
||||
Returns:
|
||||
A list containing the cache state and dependency graph.
|
||||
"""
|
||||
result = super().recursive_debug_dump()
|
||||
result.append({
|
||||
"descendants": self.descendants,
|
||||
"ancestors": self.ancestors,
|
||||
"executed_nodes": list(self.executed_nodes),
|
||||
})
|
||||
return result
|
||||
|
53
execution.py
53
execution.py
@ -15,7 +15,7 @@ import nodes
|
||||
import comfy.model_management
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy_execution.validation import validate_node_input
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@ -59,20 +59,27 @@ class IsChangedCache:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
return self.is_changed[node_id]
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, lru_size=None):
|
||||
if lru_size is None or lru_size == 0:
|
||||
self.init_classic_cache()
|
||||
else:
|
||||
self.init_lru_cache(lru_size)
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
|
||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||
# blowing away the cache every time
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
class CacheType(Enum):
|
||||
CLASSIC = 0
|
||||
LRU = 1
|
||||
DEPENDENCY_AWARE = 2
|
||||
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, cache_type=None, cache_size=None):
|
||||
if cache_type == CacheType.DEPENDENCY_AWARE:
|
||||
self.init_dependency_aware_cache()
|
||||
logging.info("Disabling intermediate node cache.")
|
||||
elif cache_type == CacheType.LRU:
|
||||
if cache_size is None:
|
||||
cache_size = 0
|
||||
self.init_lru_cache(cache_size)
|
||||
logging.info("Using LRU cache")
|
||||
else:
|
||||
self.init_classic_cache()
|
||||
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
@ -80,6 +87,17 @@ class CacheSet:
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# only hold cached items while the decendents have not executed
|
||||
def init_dependency_aware_cache(self):
|
||||
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.objects = DependencyAwareCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server, lru_size=None):
|
||||
self.lru_size = lru_size
|
||||
def __init__(self, server, cache_type=False, cache_size=None):
|
||||
self.cache_size = cache_size
|
||||
self.cache_type = cache_type
|
||||
self.server = server
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.caches = CacheSet(self.lru_size)
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
|
||||
|
8
main.py
8
main.py
@ -156,7 +156,13 @@ def cuda_malloc_warning():
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
if args.cache_lru > 0:
|
||||
cache_type = execution.CacheType.LRU
|
||||
elif args.cache_none:
|
||||
cache_type = execution.CacheType.DEPENDENCY_AWARE
|
||||
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
|
@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.14.6
|
||||
comfyui-frontend-package==1.15.13
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
response: web.Response = await handler(request)
|
||||
if request.path.endswith('.js') or request.path.endswith('.css'):
|
||||
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
|
||||
response.headers.setdefault('Cache-Control', 'no-cache')
|
||||
return response
|
||||
|
||||
|
381
tests/test_cli_args.py
Normal file
381
tests/test_cli_args.py
Normal file
@ -0,0 +1,381 @@
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import comfy.cli_args as cli_args
|
||||
import importlib
|
||||
import comfy.options
|
||||
|
||||
# No additional imports required since all necessary modules
|
||||
# (pytest, comfy.cli_args, etc.) are already imported in the test file.
|
||||
def test_is_valid_directory(tmp_path):
|
||||
"""
|
||||
Test the is_valid_directory function from comfy.cli_args.
|
||||
Verifies that:
|
||||
- Passing None returns None.
|
||||
- A valid directory returns the same path string.
|
||||
- An invalid directory path raises an argparse.ArgumentTypeError.
|
||||
"""
|
||||
assert cli_args.is_valid_directory(None) is None
|
||||
valid_dir = str(tmp_path)
|
||||
returned = cli_args.is_valid_directory(valid_dir)
|
||||
assert returned == valid_dir
|
||||
invalid_dir = os.path.join(valid_dir, "non_existing_dir")
|
||||
with pytest.raises(argparse.ArgumentTypeError) as excinfo:
|
||||
cli_args.is_valid_directory(invalid_dir)
|
||||
assert invalid_dir in str(excinfo.value)
|
||||
def test_listen_argument_no_value():
|
||||
"""
|
||||
Test that when the '--listen' argument is provided without a following value,
|
||||
the parser uses the const value "0.0.0.0,::" instead of the default.
|
||||
"""
|
||||
test_args = ["--listen"]
|
||||
args = cli_args.parser.parse_args(test_args)
|
||||
assert args.listen == "0.0.0.0,::"
|
||||
def test_preview_method_argument():
|
||||
"""
|
||||
Test that the '--preview-method' argument:
|
||||
- Correctly converts a valid value (e.g. "latent2rgb") to a LatentPreviewMethod enum instance.
|
||||
- Causes the parser to exit with an error (SystemExit) when provided an invalid value.
|
||||
"""
|
||||
valid_value = "latent2rgb"
|
||||
args = cli_args.parser.parse_args(["--preview-method", valid_value])
|
||||
assert args.preview_method == cli_args.LatentPreviewMethod.Latent2RGB
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--preview-method", "invalid_value"])
|
||||
def test_directml_argument():
|
||||
"""
|
||||
Test the '--directml' argument to ensure:
|
||||
- When provided without a value, the default const value (-1) is used.
|
||||
- When provided with an argument, the argument is correctly parsed as an integer.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--directml"])
|
||||
assert args.directml == -1
|
||||
args = cli_args.parser.parse_args(["--directml", "5"])
|
||||
assert args.directml == 5
|
||||
def test_extra_model_paths_config_argument():
|
||||
"""
|
||||
Test that the '--extra-model-paths-config' argument is parsed correctly.
|
||||
Verifies that:
|
||||
- When not provided, the default value is None.
|
||||
- When provided once with multiple values, the result is a nested list containing one list.
|
||||
- When provided multiple times, each occurrence is stored as a separate sublist.
|
||||
"""
|
||||
args = cli_args.parser.parse_args([])
|
||||
assert args.extra_model_paths_config is None
|
||||
args = cli_args.parser.parse_args(["--extra-model-paths-config", "a.yaml", "b.yaml"])
|
||||
assert args.extra_model_paths_config == [["a.yaml", "b.yaml"]]
|
||||
args = cli_args.parser.parse_args([
|
||||
"--extra-model-paths-config", "a.yaml", "b.yaml",
|
||||
"--extra-model-paths-config", "c.yaml"
|
||||
])
|
||||
assert args.extra_model_paths_config == [["a.yaml", "b.yaml"], ["c.yaml"]]
|
||||
def test_windows_standalone_build_flag():
|
||||
"""
|
||||
Test that the '--windows-standalone-build' flag correctly sets auto_launch to True,
|
||||
and that when both '--windows-standalone-build' and '--disable-auto-launch' are provided,
|
||||
auto_launch becomes False. This test manually applies the module-level post-processing logic.
|
||||
"""
|
||||
def post_process_args(ns):
|
||||
if ns.windows_standalone_build:
|
||||
ns.auto_launch = True
|
||||
if ns.disable_auto_launch:
|
||||
ns.auto_launch = False
|
||||
return ns
|
||||
args = cli_args.parser.parse_args(["--windows-standalone-build"])
|
||||
args = post_process_args(args)
|
||||
assert args.auto_launch is True
|
||||
args = cli_args.parser.parse_args(["--windows-standalone-build", "--disable-auto-launch"])
|
||||
args = post_process_args(args)
|
||||
assert args.auto_launch is False
|
||||
def test_verbose_argument():
|
||||
"""
|
||||
Test that the '--verbose' argument works correctly:
|
||||
- When provided without a value, it should default to 'DEBUG' (using its const value).
|
||||
- When provided with an explicit value, the given value should be used.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--verbose"])
|
||||
assert args.verbose == "DEBUG"
|
||||
args = cli_args.parser.parse_args(["--verbose", "WARNING"])
|
||||
assert args.verbose == "WARNING"
|
||||
def test_mutually_exclusive_cuda_malloc():
|
||||
"""
|
||||
Test that providing both mutually exclusive options '--cuda-malloc' and
|
||||
'--disable-cuda-malloc' raises a SystemExit due to the conflict.
|
||||
"""
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--cuda-malloc", "--disable-cuda-malloc"])
|
||||
def test_front_end_version_argument():
|
||||
"""
|
||||
Test that the '--front-end-version' argument:
|
||||
- Defaults to "comfyanonymous/ComfyUI@latest" when not provided.
|
||||
- Accepts and correctly parses a custom version string when provided.
|
||||
"""
|
||||
args = cli_args.parser.parse_args([])
|
||||
assert args.front_end_version == "comfyanonymous/ComfyUI@latest"
|
||||
custom_version = "user/custom@1.2.3"
|
||||
args = cli_args.parser.parse_args(["--front-end-version", custom_version])
|
||||
assert args.front_end_version == custom_version
|
||||
def test_mutually_exclusive_fpvae_group():
|
||||
"""
|
||||
Test that providing both mutually exclusive '--fp16-vae' and '--bf16-vae'
|
||||
arguments causes the parser to exit with an error.
|
||||
"""
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--fp16-vae", "--bf16-vae"])
|
||||
def test_default_values():
|
||||
"""
|
||||
Test that default values for all arguments are correctly set when no arguments are provided.
|
||||
This verifies the defaults for network settings, directory paths, various flags, and numeric options.
|
||||
"""
|
||||
args = cli_args.parser.parse_args([])
|
||||
assert args.listen == "127.0.0.1"
|
||||
assert args.port == 8188
|
||||
assert args.tls_keyfile is None
|
||||
assert args.tls_certfile is None
|
||||
assert args.enable_cors_header is None
|
||||
assert args.max_upload_size == 100
|
||||
assert args.base_directory is None
|
||||
assert args.output_directory is None
|
||||
assert args.temp_directory is None
|
||||
assert args.input_directory is None
|
||||
assert args.auto_launch is False
|
||||
assert args.disable_auto_launch is False
|
||||
assert args.cuda_device is None
|
||||
assert not getattr(args, 'cuda_malloc', False)
|
||||
assert not getattr(args, 'disable_cuda_malloc', False)
|
||||
assert not getattr(args, 'fp32_unet', False)
|
||||
assert not getattr(args, 'fp64_unet', False)
|
||||
assert not getattr(args, 'bf16_unet', False)
|
||||
assert not getattr(args, 'fp16_unet', False)
|
||||
assert not getattr(args, 'fp8_e4m3fn_unet', False)
|
||||
assert not getattr(args, 'fp8_e5m2_unet', False)
|
||||
assert not getattr(args, 'fp16_vae', False)
|
||||
assert not getattr(args, 'fp32_vae', False)
|
||||
assert not getattr(args, 'bf16_vae', False)
|
||||
assert not args.cpu_vae
|
||||
assert not getattr(args, 'fp8_e4m3fn_text_enc', False)
|
||||
assert not getattr(args, 'fp8_e5m2_text_enc', False)
|
||||
assert not getattr(args, 'fp16_text_enc', False)
|
||||
assert not getattr(args, 'fp32_text_enc', False)
|
||||
assert not getattr(args, 'force_upcast_attention', False)
|
||||
assert not getattr(args, 'dont_upcast_attention', False)
|
||||
assert not getattr(args, 'gpu_only', False)
|
||||
assert not getattr(args, 'highvram', False)
|
||||
assert not getattr(args, 'normalvram', False)
|
||||
assert not getattr(args, 'lowvram', False)
|
||||
assert not getattr(args, 'novram', False)
|
||||
assert not getattr(args, 'cpu', False)
|
||||
assert args.reserve_vram is None
|
||||
assert args.default_hashing_function == 'sha256'
|
||||
assert not args.disable_smart_memory
|
||||
assert not args.deterministic
|
||||
assert not args.fast
|
||||
assert args.verbose == 'INFO'
|
||||
assert not args.log_stdout
|
||||
assert args.front_end_version == "comfyanonymous/ComfyUI@latest"
|
||||
assert args.front_end_root is None
|
||||
assert args.user_directory is None
|
||||
assert not args.enable_compress_response_body
|
||||
def test_oneapi_device_selector_argument():
|
||||
"""
|
||||
Test that the '--oneapi-device-selector' argument is correctly parsed.
|
||||
Verifies that:
|
||||
- When not provided, the default value is None.
|
||||
- When provided with a specific string, the argument returns that string.
|
||||
"""
|
||||
args = cli_args.parser.parse_args([])
|
||||
assert args.oneapi_device_selector is None, "Default for oneapi-device-selector should be None"
|
||||
test_value = "GPU0,GPU1"
|
||||
args = cli_args.parser.parse_args(["--oneapi-device-selector", test_value])
|
||||
assert args.oneapi_device_selector == test_value, f"Expected oneapi-device-selector to be {test_value}"
|
||||
def test_tls_arguments():
|
||||
"""
|
||||
Test that TLS related arguments are correctly parsed.
|
||||
Verifies that:
|
||||
- When provided, the '--tls-keyfile' and '--tls-certfile' arguments are correctly stored.
|
||||
"""
|
||||
test_args = ["--tls-keyfile", "keyfile.pem", "--tls-certfile", "certfile.pem"]
|
||||
args = cli_args.parser.parse_args(test_args)
|
||||
assert args.tls_keyfile == "keyfile.pem"
|
||||
assert args.tls_certfile == "certfile.pem"
|
||||
def test_invalid_directml_argument():
|
||||
"""
|
||||
Test that providing a non-integer value for the '--directml' argument
|
||||
raises a SystemExit error due to the argparse type conversion failure.
|
||||
"""
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--directml", "not_a_number"])
|
||||
def test_mutually_exclusive_fpte_group():
|
||||
"""
|
||||
Test that providing mutually exclusive text encoder precision options
|
||||
(e.g. '--fp8_e4m3fn-text-enc' and '--fp16-text-enc') raises a SystemExit error.
|
||||
"""
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--fp8_e4m3fn-text-enc", "--fp16-text-enc"])
|
||||
def test_miscellaneous_flags():
|
||||
"""
|
||||
Test that miscellaneous boolean flags (disable-metadata, disable-all-custom-nodes, multi-user, and log-stdout)
|
||||
are correctly set when provided on the command line.
|
||||
"""
|
||||
args = cli_args.parser.parse_args([
|
||||
"--disable-metadata",
|
||||
"--disable-all-custom-nodes",
|
||||
"--multi-user",
|
||||
"--log-stdout"
|
||||
])
|
||||
assert args.disable_metadata is True, "Expected --disable-metadata to set disable_metadata to True"
|
||||
assert args.disable_all_custom_nodes is True, "Expected --disable-all-custom-nodes to set disable_all_custom_nodes to True"
|
||||
assert args.multi_user is True, "Expected --multi-user to set multi_user to True"
|
||||
assert args.log_stdout is True, "Expected --log-stdout to set log_stdout to True"
|
||||
def test_invalid_hashing_function_argument():
|
||||
"""
|
||||
Test that providing an invalid value for the '--default-hashing-function' argument
|
||||
raises a SystemExit error due to the invalid choice.
|
||||
"""
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--default-hashing-function", "invalid_hash"])
|
||||
def test_front_end_root_argument(tmp_path):
|
||||
"""
|
||||
Test that the '--front-end-root' argument correctly validates and returns
|
||||
the provided directory path as a string.
|
||||
"""
|
||||
valid_dir = str(tmp_path)
|
||||
args = cli_args.parser.parse_args(["--front-end-root", valid_dir])
|
||||
assert args.front_end_root == valid_dir
|
||||
def test_enable_cors_header_argument():
|
||||
"""
|
||||
Test that the '--enable-cors-header' argument is parsed correctly:
|
||||
- When provided without a value, it should use the const value "*"
|
||||
(indicating allow all origins).
|
||||
- When provided with an explicit value, it returns that value.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--enable-cors-header"])
|
||||
assert args.enable_cors_header == "*", "Expected --enable-cors-header with no value to default to '*'"
|
||||
test_origin = "http://example.com"
|
||||
args = cli_args.parser.parse_args(["--enable-cors-header", test_origin])
|
||||
assert args.enable_cors_header == test_origin, "Expected --enable-cors-header to use the provided origin"
|
||||
def test_listen_argument_with_explicit_value():
|
||||
"""
|
||||
Test that providing an explicit IP address with '--listen' sets the value correctly.
|
||||
"""
|
||||
test_ip = "192.168.1.100"
|
||||
args = cli_args.parser.parse_args(["--listen", test_ip])
|
||||
assert args.listen == test_ip
|
||||
def test_cache_arguments():
|
||||
"""
|
||||
Test that the caching arguments are correctly parsed.
|
||||
Verifies that:
|
||||
- When provided with '--cache-lru' and an integer value, the cache_lru attribute is set accordingly and cache_classic is False.
|
||||
- When provided with '--cache-classic', the cache_classic flag is True and cache_lru remains at its default value (0).
|
||||
- Providing both options simultaneously raises a SystemExit error due to mutual exclusivity.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--cache-lru", "15"])
|
||||
assert args.cache_lru == 15
|
||||
assert args.cache_classic is False
|
||||
args = cli_args.parser.parse_args(["--cache-classic"])
|
||||
assert args.cache_classic is True
|
||||
assert args.cache_lru == 0
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--cache-lru", "15", "--cache-classic"])
|
||||
def test_invalid_port_argument():
|
||||
"""
|
||||
Test that passing a non-integer value to the '--port' argument
|
||||
raises a SystemExit error due to type conversion failure.
|
||||
"""
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--port", "not_an_integer"])
|
||||
def test_mutually_exclusive_fpunet_group():
|
||||
"""
|
||||
Test that providing mutually exclusive unet precision options (e.g. '--fp32-unet' and '--fp16-unet')
|
||||
raises a SystemExit error due to the conflict in the fpunet group.
|
||||
"""
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--fp32-unet", "--fp16-unet"])
|
||||
def test_mutually_exclusive_upcast_arguments():
|
||||
"""
|
||||
Test the mutual exclusivity of the upcast arguments:
|
||||
--force-upcast-attention and --dont-upcast-attention.
|
||||
|
||||
This test verifies that:
|
||||
- Providing only --force-upcast-attention sets its flag to True while --dont-upcast-attention remains False.
|
||||
- Providing only --dont-upcast-attention sets its flag to True while --force-upcast-attention remains False.
|
||||
- Providing both flags simultaneously raises a SystemExit error.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--force-upcast-attention"])
|
||||
assert getattr(args, "force_upcast_attention", False) is True
|
||||
assert getattr(args, "dont_upcast_attention", False) is False
|
||||
args = cli_args.parser.parse_args(["--dont-upcast-attention"])
|
||||
assert getattr(args, "dont_upcast_attention", False) is True
|
||||
assert getattr(args, "force_upcast_attention", False) is False
|
||||
with pytest.raises(SystemExit):
|
||||
cli_args.parser.parse_args(["--force-upcast-attention", "--dont-upcast-attention"])
|
||||
def test_disable_xformers_argument():
|
||||
"""
|
||||
Test that the '--disable-xformers' argument correctly sets the disable_xformers flag to True.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--disable-xformers"])
|
||||
assert args.disable_xformers is True
|
||||
def test_module_args_parsing_behavior():
|
||||
"""
|
||||
Test the module-level args parsing behavior.
|
||||
This test temporarily sets comfy.options.args_parsing to False and reloads the cli_args module.
|
||||
When args_parsing is False, the parser should be forced to parse an empty list (using defaults).
|
||||
We verify that the globals in cli_args.args have the expected default values.
|
||||
"""
|
||||
# Save the original value
|
||||
original_args_parsing = comfy.options.args_parsing
|
||||
try:
|
||||
# Set args_parsing to False so that parser.parse_args([]) is used at the module level
|
||||
comfy.options.args_parsing = False
|
||||
reloaded_cli_args = importlib.reload(cli_args)
|
||||
# Since no arguments are provided by default, the defaults should be assigned.
|
||||
# For instance, 'listen' should be "127.0.0.1" and 'port' should be 8188.
|
||||
assert reloaded_cli_args.args.listen == "127.0.0.1", "Expected default listen to be 127.0.0.1"
|
||||
assert reloaded_cli_args.args.port == 8188, "Expected default port to be 8188"
|
||||
# Additionally, we want to check that a flag (like auto_launch) remains at its default (False)
|
||||
assert reloaded_cli_args.args.auto_launch is False, "Expected auto_launch to be False when no args provided"
|
||||
finally:
|
||||
# Restore the original args_parsing value and reload the module again to reset state.
|
||||
comfy.options.args_parsing = original_args_parsing
|
||||
importlib.reload(cli_args) # reset to original state
|
||||
def test_fpunet_valid_option():
|
||||
"""
|
||||
Test that providing the '--fp16-unet' argument correctly sets the fp16_unet flag,
|
||||
and does not set any other mutually exclusive fpunet options.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--fp16-unet"])
|
||||
assert args.fp16_unet is True, "Expected --fp16-unet to set fp16_unet to True"
|
||||
assert not getattr(args, "fp32_unet", False), "Unexpected --fp32-unet flag when --fp16-unet is provided"
|
||||
assert not getattr(args, "fp64_unet", False), "Unexpected --fp64-unet flag when --fp16-unet is provided"
|
||||
assert not getattr(args, "bf16_unet", False), "Unexpected --bf16-unet flag when --fp16-unet is provided"
|
||||
assert not getattr(args, "fp8_e4m3fn_unet", False), "Unexpected --fp8_e4m3fn-unet flag when --fp16-unet is provided"
|
||||
assert not getattr(args, "fp8_e5m2_unet", False), "Unexpected --fp8_e5m2-unet flag when --fp16-unet is provided"
|
||||
def test_dont_print_server_and_quick_test_for_ci():
|
||||
"""
|
||||
Test that the '--dont-print-server' and '--quick-test-for-ci' flags are correctly parsed.
|
||||
Verifies that when not provided, their values default to False, and when provided,
|
||||
they are set to True.
|
||||
"""
|
||||
# Test with no flags provided (should use defaults)
|
||||
args = cli_args.parser.parse_args([])
|
||||
assert args.dont_print_server is False, "Expected default dont_print_server to be False"
|
||||
assert args.quick_test_for_ci is False, "Expected default quick_test_for_ci to be False"
|
||||
|
||||
# Test with both flags provided on the command line
|
||||
args = cli_args.parser.parse_args(["--dont-print-server", "--quick-test-for-ci"])
|
||||
assert args.dont_print_server is True, "Expected dont_print_server to be True when flag is provided"
|
||||
assert args.quick_test_for_ci is True, "Expected quick_test_for_ci to be True when flag is provided"
|
||||
def test_vram_group_lowvram_flag():
|
||||
"""
|
||||
Test that providing the '--lowvram' flag correctly sets the lowvram flag
|
||||
and ensures that all other mutually exclusive vram flags remain False.
|
||||
"""
|
||||
args = cli_args.parser.parse_args(["--lowvram"])
|
||||
assert args.lowvram is True
|
||||
# Ensure that no other mutually exclusive vram flags are set
|
||||
assert not getattr(args, "gpu_only", False)
|
||||
assert not getattr(args, "highvram", False)
|
||||
assert not getattr(args, "normalvram", False)
|
||||
assert not getattr(args, "novram", False)
|
||||
assert not getattr(args, "cpu", False)
|
Loading…
Reference in New Issue
Block a user