mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 08:33:29 +00:00
Compare commits
11 Commits
6cc6ffdfdd
...
258e91d20c
Author | SHA1 | Date | |
---|---|---|---|
![]() |
258e91d20c | ||
![]() |
22ad513c72 | ||
![]() |
ed945a1790 | ||
![]() |
f9207c6936 | ||
![]() |
8ad7477647 | ||
![]() |
618a7a3fea | ||
![]() |
c868cb2055 | ||
![]() |
9f9db7fc29 | ||
![]() |
ae0b0da8b8 | ||
![]() |
bdc8c2a8c7 | ||
![]() |
0a72baba13 |
@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
|
|||||||
cache_group = parser.add_mutually_exclusive_group()
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
|
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||||
|
@ -316,3 +316,156 @@ class LRUCache(BasicCache):
|
|||||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyAwareCache(BasicCache):
|
||||||
|
"""
|
||||||
|
A cache implementation that tracks dependencies between nodes and manages
|
||||||
|
their execution and caching accordingly. It extends the BasicCache class.
|
||||||
|
Nodes are removed from this cache once all of their descendants have been
|
||||||
|
executed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key_class):
|
||||||
|
"""
|
||||||
|
Initialize the DependencyAwareCache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_class: The class used for generating cache keys.
|
||||||
|
"""
|
||||||
|
super().__init__(key_class)
|
||||||
|
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||||
|
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||||
|
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||||
|
|
||||||
|
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
"""
|
||||||
|
Clear the entire cache and rebuild the dependency graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dynprompt: The dynamic prompt object containing node information.
|
||||||
|
node_ids: List of node IDs to initialize the cache for.
|
||||||
|
is_changed_cache: Flag indicating if the cache has changed.
|
||||||
|
"""
|
||||||
|
# Clear all existing cache data
|
||||||
|
self.cache.clear()
|
||||||
|
self.subcaches.clear()
|
||||||
|
self.descendants.clear()
|
||||||
|
self.ancestors.clear()
|
||||||
|
self.executed_nodes.clear()
|
||||||
|
|
||||||
|
# Call the parent method to initialize the cache with the new prompt
|
||||||
|
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||||
|
|
||||||
|
# Rebuild the dependency graph
|
||||||
|
self._build_dependency_graph(dynprompt, node_ids)
|
||||||
|
|
||||||
|
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||||
|
"""
|
||||||
|
Build the dependency graph for all nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dynprompt: The dynamic prompt object containing node information.
|
||||||
|
node_ids: List of node IDs to build the graph for.
|
||||||
|
"""
|
||||||
|
self.descendants.clear()
|
||||||
|
self.ancestors.clear()
|
||||||
|
for node_id in node_ids:
|
||||||
|
self.descendants[node_id] = set()
|
||||||
|
self.ancestors[node_id] = set()
|
||||||
|
|
||||||
|
for node_id in node_ids:
|
||||||
|
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||||
|
for input_data in inputs.values():
|
||||||
|
if is_link(input_data): # Check if the input is a link to another node
|
||||||
|
ancestor_id = input_data[0]
|
||||||
|
self.descendants[ancestor_id].add(node_id)
|
||||||
|
self.ancestors[node_id].add(ancestor_id)
|
||||||
|
|
||||||
|
def set(self, node_id, value):
|
||||||
|
"""
|
||||||
|
Mark a node as executed and store its value in the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: The ID of the node to store.
|
||||||
|
value: The value to store for the node.
|
||||||
|
"""
|
||||||
|
self._set_immediate(node_id, value)
|
||||||
|
self.executed_nodes.add(node_id)
|
||||||
|
self._cleanup_ancestors(node_id)
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
"""
|
||||||
|
Retrieve the cached value for a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: The ID of the node to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The cached value for the node.
|
||||||
|
"""
|
||||||
|
return self._get_immediate(node_id)
|
||||||
|
|
||||||
|
def ensure_subcache_for(self, node_id, children_ids):
|
||||||
|
"""
|
||||||
|
Ensure a subcache exists for a node and update dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: The ID of the parent node.
|
||||||
|
children_ids: List of child node IDs to associate with the parent node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The subcache object for the node.
|
||||||
|
"""
|
||||||
|
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||||
|
for child_id in children_ids:
|
||||||
|
self.descendants[node_id].add(child_id)
|
||||||
|
self.ancestors[child_id].add(node_id)
|
||||||
|
return subcache
|
||||||
|
|
||||||
|
def _cleanup_ancestors(self, node_id):
|
||||||
|
"""
|
||||||
|
Check if ancestors of a node can be removed from the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: The ID of the node whose ancestors are to be checked.
|
||||||
|
"""
|
||||||
|
for ancestor_id in self.ancestors.get(node_id, []):
|
||||||
|
if ancestor_id in self.executed_nodes:
|
||||||
|
# Remove ancestor if all its descendants have been executed
|
||||||
|
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||||
|
self._remove_node(ancestor_id)
|
||||||
|
|
||||||
|
def _remove_node(self, node_id):
|
||||||
|
"""
|
||||||
|
Remove a node from the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: The ID of the node to remove.
|
||||||
|
"""
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
if cache_key in self.cache:
|
||||||
|
del self.cache[cache_key]
|
||||||
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
|
if subcache_key in self.subcaches:
|
||||||
|
del self.subcaches[subcache_key]
|
||||||
|
|
||||||
|
def clean_unused(self):
|
||||||
|
"""
|
||||||
|
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def recursive_debug_dump(self):
|
||||||
|
"""
|
||||||
|
Dump the cache and dependency graph for debugging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list containing the cache state and dependency graph.
|
||||||
|
"""
|
||||||
|
result = super().recursive_debug_dump()
|
||||||
|
result.append({
|
||||||
|
"descendants": self.descendants,
|
||||||
|
"ancestors": self.ancestors,
|
||||||
|
"executed_nodes": list(self.executed_nodes),
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from comfy.k_diffusion.sampling import default_noise_sampler
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
@ -55,6 +56,70 @@ class SamplerLCMUpscale:
|
|||||||
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_lcm_scalewise(model, x, sigmas, extra_args=None, callback=None, disable=None, upscales=None, upscale_method="bicubic"):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
seed = extra_args.get("seed", None)
|
||||||
|
|
||||||
|
if upscales is not None:
|
||||||
|
# Resolution is increased on each step except the last one
|
||||||
|
assert len(upscales) == len(sigmas) - 2
|
||||||
|
|
||||||
|
orig_shape = x.size()
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
|
||||||
|
x = denoised
|
||||||
|
if i < len(upscales):
|
||||||
|
x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled")
|
||||||
|
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
# Since the size of noise if changing, noise_sampler has to be redefined each time
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed)
|
||||||
|
# Noise using the model's scheduler
|
||||||
|
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerLCMScalewise:
|
||||||
|
upscale_methods = ["bicubic", "bilinear", "nearest-exact"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required":
|
||||||
|
{
|
||||||
|
"upscales": ("STRING", {"default": ""}),
|
||||||
|
"upscale_method": (s.upscale_methods,),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SAMPLER",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/samplers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sampler"
|
||||||
|
|
||||||
|
def _validate_upscales(self, upscales):
|
||||||
|
if not upscales:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i in range(1, len(upscales)):
|
||||||
|
if upscales[i] < upscales[i-1]:
|
||||||
|
raise ValueError("`upscales` is expected to be non-decreasing sequence of numbers")
|
||||||
|
|
||||||
|
def get_sampler(self, upscales, upscale_method):
|
||||||
|
# Turn comma-separated list into string
|
||||||
|
upscales = [float(value) for value in upscales.split(',')]
|
||||||
|
self._validate_upscales(upscales)
|
||||||
|
if len(upscales) == 0:
|
||||||
|
upscales = None
|
||||||
|
sampler = comfy.samplers.KSAMPLER(sample_lcm_scalewise, extra_options={"upscales": upscales, "upscale_method": upscale_method})
|
||||||
|
return (sampler, )
|
||||||
|
|
||||||
|
|
||||||
from comfy.k_diffusion.sampling import to_d
|
from comfy.k_diffusion.sampling import to_d
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
|
||||||
@ -103,6 +168,7 @@ class SamplerEulerCFGpp:
|
|||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SamplerLCMUpscale": SamplerLCMUpscale,
|
"SamplerLCMUpscale": SamplerLCMUpscale,
|
||||||
|
"SamplerLCMScalewise": SamplerLCMScalewise,
|
||||||
"SamplerEulerCFGpp": SamplerEulerCFGpp,
|
"SamplerEulerCFGpp": SamplerEulerCFGpp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
53
execution.py
53
execution.py
@ -15,7 +15,7 @@ import nodes
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
@ -59,20 +59,27 @@ class IsChangedCache:
|
|||||||
self.is_changed[node_id] = node["is_changed"]
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
class CacheSet:
|
|
||||||
def __init__(self, lru_size=None):
|
|
||||||
if lru_size is None or lru_size == 0:
|
|
||||||
self.init_classic_cache()
|
|
||||||
else:
|
|
||||||
self.init_lru_cache(lru_size)
|
|
||||||
self.all = [self.outputs, self.ui, self.objects]
|
|
||||||
|
|
||||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
class CacheType(Enum):
|
||||||
# blowing away the cache every time
|
CLASSIC = 0
|
||||||
def init_lru_cache(self, cache_size):
|
LRU = 1
|
||||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
DEPENDENCY_AWARE = 2
|
||||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
|
||||||
self.objects = HierarchicalCache(CacheKeySetID)
|
|
||||||
|
class CacheSet:
|
||||||
|
def __init__(self, cache_type=None, cache_size=None):
|
||||||
|
if cache_type == CacheType.DEPENDENCY_AWARE:
|
||||||
|
self.init_dependency_aware_cache()
|
||||||
|
logging.info("Disabling intermediate node cache.")
|
||||||
|
elif cache_type == CacheType.LRU:
|
||||||
|
if cache_size is None:
|
||||||
|
cache_size = 0
|
||||||
|
self.init_lru_cache(cache_size)
|
||||||
|
logging.info("Using LRU cache")
|
||||||
|
else:
|
||||||
|
self.init_classic_cache()
|
||||||
|
|
||||||
|
self.all = [self.outputs, self.ui, self.objects]
|
||||||
|
|
||||||
# Performs like the old cache -- dump data ASAP
|
# Performs like the old cache -- dump data ASAP
|
||||||
def init_classic_cache(self):
|
def init_classic_cache(self):
|
||||||
@ -80,6 +87,17 @@ class CacheSet:
|
|||||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||||
self.objects = HierarchicalCache(CacheKeySetID)
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
|
def init_lru_cache(self, cache_size):
|
||||||
|
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||||
|
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||||
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
|
# only hold cached items while the decendents have not executed
|
||||||
|
def init_dependency_aware_cache(self):
|
||||||
|
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
|
||||||
|
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
|
||||||
|
self.objects = DependencyAwareCache(CacheKeySetID)
|
||||||
|
|
||||||
def recursive_debug_dump(self):
|
def recursive_debug_dump(self):
|
||||||
result = {
|
result = {
|
||||||
"outputs": self.outputs.recursive_debug_dump(),
|
"outputs": self.outputs.recursive_debug_dump(),
|
||||||
@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server, lru_size=None):
|
def __init__(self, server, cache_type=False, cache_size=None):
|
||||||
self.lru_size = lru_size
|
self.cache_size = cache_size
|
||||||
|
self.cache_type = cache_type
|
||||||
self.server = server
|
self.server = server
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.caches = CacheSet(self.lru_size)
|
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.success = True
|
self.success = True
|
||||||
|
|
||||||
|
8
main.py
8
main.py
@ -156,7 +156,13 @@ def cuda_malloc_warning():
|
|||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
|
cache_type = execution.CacheType.CLASSIC
|
||||||
|
if args.cache_lru > 0:
|
||||||
|
cache_type = execution.CacheType.LRU
|
||||||
|
elif args.cache_none:
|
||||||
|
cache_type = execution.CacheType.DEPENDENCY_AWARE
|
||||||
|
|
||||||
|
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
|
||||||
last_gc_collect = 0
|
last_gc_collect = 0
|
||||||
need_gc = False
|
need_gc = False
|
||||||
gc_collect_interval = 10.0
|
gc_collect_interval = 10.0
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.14.6
|
comfyui-frontend-package==1.15.13
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
|
@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
|
|||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
response: web.Response = await handler(request)
|
response: web.Response = await handler(request)
|
||||||
if request.path.endswith('.js') or request.path.endswith('.css'):
|
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
|
||||||
response.headers.setdefault('Cache-Control', 'no-cache')
|
response.headers.setdefault('Cache-Control', 'no-cache')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user