Compare commits


6 Commits

Author SHA1 Message Date
silveroxides
cd3d2d5c62 remove unused imports 2025-04-11 16:22:04 +02:00
silveroxides
2c2481955d reduce code duplication 2025-04-11 16:18:57 +02:00
Silver
fc978a7ad8 Merge branch 'comfyanonymous:master' into chroma-support 2025-04-11 13:07:36 +02:00
Chargeuk
ed945a1790 Dependency Aware Node Caching for low RAM/VRAM machines (#7509) 2025-04-11 06:55:51 -04:00

* add dependency aware cache that removes a cached node as soon as all of its descendants have executed. This allows users with lower RAM to run workflows they would otherwise not be able to run. The downside is that every workflow will fully run each time even if no nodes have changed.

* remove test code

* tidy code
Chenlei Hu
f9207c6936 Update frontend to 1.15 (#7564) 2025-04-11 06:46:20 -04:00
Christian Byrne
8ad7477647 don't cache templates index (#7569) 2025-04-11 06:06:53 -04:00
8 changed files with 197 additions and 136 deletions

View File: comfy/cli_args.py

@@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
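The new flag joins the existing cache flags in a mutually exclusive group, so only one caching strategy can be selected per launch. A minimal standalone sketch of that argparse behavior (help strings omitted):

import argparse

parser = argparse.ArgumentParser()
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true")
cache_group.add_argument("--cache-lru", type=int, default=0)
cache_group.add_argument("--cache-none", action="store_true")

print(parser.parse_args(["--cache-none"]))
# Namespace(cache_classic=False, cache_lru=0, cache_none=True)
# parser.parse_args(["--cache-none", "--cache-lru", "8"]) would exit with
# "argument --cache-lru: not allowed with argument --cache-none"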

View File: comfy/ldm/chroma/layers.py

@@ -1,119 +1,29 @@
-import math
-from dataclasses import dataclass
-
 import torch
 from torch import Tensor, nn

-from .math import attention, rope
-import comfy.ops
-import comfy.ldm.common_dit
+from .math import attention
+from comfy.ldm.flux.layers import (
+    MLPEmbedder,
+    RMSNorm,
+    QKNorm,
+    SelfAttention,
+    ModulationOut,
+)


-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: Tensor) -> Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-        return emb.unsqueeze(1)
-
-
-def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
-    """
-    Create sinusoidal timestep embeddings.
-    :param t: a 1-D Tensor of N indices, one per batch element.
-              These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an (N, D) Tensor of positional embeddings.
-    """
-    t = time_factor * t
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
-
-    args = t[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    if torch.is_floating_point(t):
-        embedding = embedding.to(t)
-    return embedding
-
-class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-        self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.out_layer(self.silu(self.in_layer(x)))
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
-
-class QKNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-
-    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-        return q.to(v), k.to(v)
-
-
-class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
-
-
-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
-class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.is_double = double
-        self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
-
-        return (
-            ModulationOut(*out[:3]),
-            ModulationOut(*out[3:]) if self.is_double else None,
-        )
+
+
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
+        )


 class Approximator(nn.Module):
     def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
         super().__init__()
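The net effect of this hunk is that the chroma layers now reuse MLPEmbedder, RMSNorm, QKNorm, SelfAttention, and ModulationOut from comfy.ldm.flux.layers instead of duplicating them; only the new ChromaModulationOut helper remains, slicing one (shift, scale, gate) triple out of a stacked modulation tensor. A usage sketch for that helper (import path taken from this branch's file layout; shapes are assumed, batch 2 and hidden size 64):

import torch
from comfy.ldm.chroma.layers import ChromaModulationOut

mods = torch.randn(2, 6, 64)                          # two stacked (shift, scale, gate) triples
m0 = ChromaModulationOut.from_offset(mods, offset=0)  # slices rows 0..2
m1 = ChromaModulationOut.from_offset(mods, offset=3)  # slices rows 3..5
print(m0.shift.shape, m1.gate.shape)                  # torch.Size([2, 1, 64]) each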

View File: comfy/ldm/chroma/model.py

@@ -7,15 +7,17 @@ from torch import Tensor, nn
 from einops import rearrange, repeat

 import comfy.ldm.common_dit
+from comfy.ldm.flux.layers import (
+    EmbedND,
+    timestep_embedding,
+)

 from .layers import (
     DoubleStreamBlock,
-    EmbedND,
     LastLayer,
-    MLPEmbedder,
     SingleStreamBlock,
-    timestep_embedding,
     Approximator,
-    ModulationOut
+    ChromaModulationOut,
 )

@@ -39,14 +41,6 @@ class ChromaParams:
     n_layers: int


-class ChromaModulationOut(ModulationOut):
-    @classmethod
-    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
-        return cls(
-            shift=tensor[:, offset : offset + 1, :],
-            scale=tensor[:, offset + 1 : offset + 2, :],
-            gate=tensor[:, offset + 2 : offset + 3, :],
-        )


 class Chroma(nn.Module):

@@ -77,7 +71,6 @@ class Chroma(nn.Module):
         self.n_layers = params.n_layers
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
         self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
         # set as nn identity for now, will overwrite it later.
         self.distilled_guidance_layer = Approximator(

@@ -88,9 +81,6 @@ class Chroma(nn.Module):
                     dtype=dtype, device=device, operations=operations
                 )

-        self.guidance_in = (
-            MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity()
-        )

         self.double_blocks = nn.ModuleList(
             [
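Chroma now imports EmbedND and timestep_embedding from the flux layers rather than carrying its own copies, and drops the time_in and guidance_in MLPEmbedders in favor of the Approximator-based distilled_guidance_layer. A small usage sketch of the shared embedding function (shapes assumed; requires a ComfyUI checkout on the import path):

import torch
from comfy.ldm.flux.layers import timestep_embedding

t = torch.tensor([0.0, 0.5, 1.0])      # one fractional timestep per batch element
emb = timestep_embedding(t, dim=64)    # scaled by time_factor=1000 internally
print(emb.shape)                       # torch.Size([3, 64]): interleaved cos/sin frequencies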

View File: comfy_execution/caching.py

@@ -316,3 +316,156 @@ class LRUCache(BasicCache):
         self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))

         return self
+
+class DependencyAwareCache(BasicCache):
+    """
+    A cache implementation that tracks dependencies between nodes and manages
+    their execution and caching accordingly. It extends the BasicCache class.
+    Nodes are removed from this cache once all of their descendants have been
+    executed.
+    """
+
+    def __init__(self, key_class):
+        """
+        Initialize the DependencyAwareCache.
+
+        Args:
+            key_class: The class used for generating cache keys.
+        """
+        super().__init__(key_class)
+        self.descendants = {}  # Maps node_id -> set of descendant node_ids
+        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
+        self.executed_nodes = set()  # Tracks nodes that have been executed
+
+    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        """
+        Clear the entire cache and rebuild the dependency graph.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to initialize the cache for.
+            is_changed_cache: Flag indicating if the cache has changed.
+        """
+        # Clear all existing cache data
+        self.cache.clear()
+        self.subcaches.clear()
+        self.descendants.clear()
+        self.ancestors.clear()
+        self.executed_nodes.clear()
+
+        # Call the parent method to initialize the cache with the new prompt
+        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+
+        # Rebuild the dependency graph
+        self._build_dependency_graph(dynprompt, node_ids)
+
+    def _build_dependency_graph(self, dynprompt, node_ids):
+        """
+        Build the dependency graph for all nodes.
+
+        Args:
+            dynprompt: The dynamic prompt object containing node information.
+            node_ids: List of node IDs to build the graph for.
+        """
+        self.descendants.clear()
+        self.ancestors.clear()
+        for node_id in node_ids:
+            self.descendants[node_id] = set()
+            self.ancestors[node_id] = set()
+
+        for node_id in node_ids:
+            inputs = dynprompt.get_node(node_id)["inputs"]
+            for input_data in inputs.values():
+                if is_link(input_data):  # Check if the input is a link to another node
+                    ancestor_id = input_data[0]
+                    self.descendants[ancestor_id].add(node_id)
+                    self.ancestors[node_id].add(ancestor_id)
+
+    def set(self, node_id, value):
+        """
+        Mark a node as executed and store its value in the cache.
+
+        Args:
+            node_id: The ID of the node to store.
+            value: The value to store for the node.
+        """
+        self._set_immediate(node_id, value)
+        self.executed_nodes.add(node_id)
+        self._cleanup_ancestors(node_id)
+
+    def get(self, node_id):
+        """
+        Retrieve the cached value for a node.
+
+        Args:
+            node_id: The ID of the node to retrieve.
+
+        Returns:
+            The cached value for the node.
+        """
+        return self._get_immediate(node_id)
+
+    def ensure_subcache_for(self, node_id, children_ids):
+        """
+        Ensure a subcache exists for a node and update dependencies.
+
+        Args:
+            node_id: The ID of the parent node.
+            children_ids: List of child node IDs to associate with the parent node.
+
+        Returns:
+            The subcache object for the node.
+        """
+        subcache = super()._ensure_subcache(node_id, children_ids)
+        for child_id in children_ids:
+            self.descendants[node_id].add(child_id)
+            self.ancestors[child_id].add(node_id)
+        return subcache
+
+    def _cleanup_ancestors(self, node_id):
+        """
+        Check if ancestors of a node can be removed from the cache.
+
+        Args:
+            node_id: The ID of the node whose ancestors are to be checked.
+        """
+        for ancestor_id in self.ancestors.get(node_id, []):
+            if ancestor_id in self.executed_nodes:
+                # Remove ancestor if all its descendants have been executed
+                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
+                    self._remove_node(ancestor_id)
+
+    def _remove_node(self, node_id):
+        """
+        Remove a node from the cache.
+
+        Args:
+            node_id: The ID of the node to remove.
+        """
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        if cache_key in self.cache:
+            del self.cache[cache_key]
+        subcache_key = self.cache_key_set.get_subcache_key(node_id)
+        if subcache_key in self.subcaches:
+            del self.subcaches[subcache_key]
+
+    def clean_unused(self):
+        """
+        Clean up unused nodes. This is a no-op for this cache implementation.
+        """
+        pass
+
+    def recursive_debug_dump(self):
+        """
+        Dump the cache and dependency graph for debugging.
+
+        Returns:
+            A list containing the cache state and dependency graph.
+        """
+        result = super().recursive_debug_dump()
+        result.append({
+            "descendants": self.descendants,
+            "ancestors": self.ancestors,
+            "executed_nodes": list(self.executed_nodes),
+        })
+        return result
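The core idea of the class above is the eviction rule in _cleanup_ancestors: a node's cached output is dropped as soon as every one of its direct consumers has executed. A standalone sketch of just that rule, using a hypothetical three-node chain a -> b -> c (not ComfyUI code; the real class only re-checks the ancestors of the node that just finished, rather than scanning every node):

descendants = {"a": {"b"}, "b": {"c"}, "c": set()}
executed, cache = set(), {}

def finish(node, value):
    executed.add(node)
    cache[node] = value
    for anc, consumers in descendants.items():
        if anc in executed and consumers and consumers <= executed and anc in cache:
            del cache[anc]  # every direct consumer has run, so free the output

finish("a", "out_a"); print(sorted(cache))  # ['a']
finish("b", "out_b"); print(sorted(cache))  # ['b']  ('a' freed once b consumed it)
finish("c", "out_c"); print(sorted(cache))  # ['c']  ('b' freed)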

View File: execution.py

@@ -15,7 +15,7 @@ import nodes
 import comfy.model_management
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
 from comfy_execution.validation import validate_node_input

 class ExecutionResult(Enum):

@@ -60,26 +60,32 @@ class IsChangedCache:
         return self.is_changed[node_id]

 class CacheSet:
-    def __init__(self, lru_size=None):
-        if lru_size is None or lru_size == 0:
+    def __init__(self, lru_size=None, cache_none=False):
+        if cache_none:
+            self.init_dependency_aware_cache()
+        elif lru_size is None or lru_size == 0:
             self.init_classic_cache()
         else:
             self.init_lru_cache(lru_size)
         self.all = [self.outputs, self.ui, self.objects]

-    # Useful for those with ample RAM/VRAM -- allows experimenting without
-    # blowing away the cache every time
-    def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.objects = HierarchicalCache(CacheKeySetID)
-
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
         self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    # only hold cached items while the descendants have not executed
+    def init_dependency_aware_cache(self):
+        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
+        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
+        self.objects = DependencyAwareCache(CacheKeySetID)
+
     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),

@@ -414,13 +420,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, lru_size=None):
+    def __init__(self, server, lru_size=None, cache_none=False):
         self.lru_size = lru_size
+        self.cache_none = cache_none
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(self.lru_size)
+        self.caches = CacheSet(self.lru_size, self.cache_none)
         self.status_messages = []
         self.success = True
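Taken together, CacheSet now selects one of three strategies, with cache_none taking precedence over the LRU size. A minimal sketch mirroring the selection order in CacheSet.__init__ above (returns class names only, for illustration):

def pick_cache(lru_size=None, cache_none=False):
    # mirrors the branch order in CacheSet.__init__
    if cache_none:
        return "DependencyAwareCache"  # lowest RAM: outputs freed as descendants finish
    elif lru_size is None or lru_size == 0:
        return "HierarchicalCache"     # classic behavior: dump data ASAP
    else:
        return "LRUCache"              # keep up to lru_size node results

assert pick_cache(cache_none=True) == "DependencyAwareCache"
assert pick_cache(lru_size=16) == "LRUCache"
assert pick_cache() == "HierarchicalCache"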

View File: main.py

@@ -156,7 +156,7 @@ def cuda_malloc_warning():

 def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0

View File: requirements.txt

@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.14.6
+comfyui-frontend-package==1.15.13
 torch
 torchsde
 torchvision

View File: server.py

@@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):

 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
-    if request.path.endswith('.js') or request.path.endswith('.css'):
+    if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
        response.headers.setdefault('Cache-Control', 'no-cache')
     return response
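With the added clause, any request path ending in index.json (such as the frontend's templates index from the commit message) now gets the same no-cache header as scripts and stylesheets. A quick standalone check of the extended path predicate (sample paths are assumptions):

for path in ("app.js", "style.css", "templates/index.json", "logo.png"):
    no_cache = path.endswith('.js') or path.endswith('.css') or path.endswith('index.json')
    print(path, '->', 'Cache-Control: no-cache' if no_cache else 'cacheable')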