mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 08:03:30 +00:00
Compare commits
6 Commits
1c76711441
...
cd3d2d5c62
Author | SHA1 | Date | |
---|---|---|---|
![]() |
cd3d2d5c62 | ||
![]() |
2c2481955d | ||
![]() |
fc978a7ad8 | ||
![]() |
ed945a1790 | ||
![]() |
f9207c6936 | ||
![]() |
8ad7477647 |
@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
|
@ -1,116 +1,26 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
from .math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@classmethod
|
||||
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
|
||||
return cls(
|
||||
shift=tensor[:, offset : offset + 1, :],
|
||||
scale=tensor[:, offset + 1 : offset + 2, :],
|
||||
gate=tensor[:, offset + 2 : offset + 3, :],
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
@ -7,15 +7,17 @@ from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Approximator,
|
||||
ModulationOut
|
||||
ChromaModulationOut,
|
||||
)
|
||||
|
||||
|
||||
@ -39,14 +41,6 @@ class ChromaParams:
|
||||
n_layers: int
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@classmethod
|
||||
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
|
||||
return cls(
|
||||
shift=tensor[:, offset : offset + 1, :],
|
||||
scale=tensor[:, offset + 1 : offset + 2, :],
|
||||
gate=tensor[:, offset + 2 : offset + 3, :],
|
||||
)
|
||||
|
||||
|
||||
class Chroma(nn.Module):
|
||||
@ -77,7 +71,6 @@ class Chroma(nn.Module):
|
||||
self.n_layers = params.n_layers
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
# set as nn identity for now, will overwrite it later.
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
@ -88,9 +81,6 @@ class Chroma(nn.Module):
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity()
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
|
@ -316,3 +316,156 @@ class LRUCache(BasicCache):
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
|
||||
class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
A cache implementation that tracks dependencies between nodes and manages
|
||||
their execution and caching accordingly. It extends the BasicCache class.
|
||||
Nodes are removed from this cache once all of their descendants have been
|
||||
executed.
|
||||
"""
|
||||
|
||||
def __init__(self, key_class):
|
||||
"""
|
||||
Initialize the DependencyAwareCache.
|
||||
|
||||
Args:
|
||||
key_class: The class used for generating cache keys.
|
||||
"""
|
||||
super().__init__(key_class)
|
||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to initialize the cache for.
|
||||
is_changed_cache: Flag indicating if the cache has changed.
|
||||
"""
|
||||
# Clear all existing cache data
|
||||
self.cache.clear()
|
||||
self.subcaches.clear()
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
|
||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||
"""
|
||||
Build the dependency graph for all nodes.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to build the graph for.
|
||||
"""
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
for node_id in node_ids:
|
||||
self.descendants[node_id] = set()
|
||||
self.ancestors[node_id] = set()
|
||||
|
||||
for node_id in node_ids:
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
for input_data in inputs.values():
|
||||
if is_link(input_data): # Check if the input is a link to another node
|
||||
ancestor_id = input_data[0]
|
||||
self.descendants[ancestor_id].add(node_id)
|
||||
self.ancestors[node_id].add(ancestor_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
"""
|
||||
Mark a node as executed and store its value in the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to store.
|
||||
value: The value to store for the node.
|
||||
"""
|
||||
self._set_immediate(node_id, value)
|
||||
self.executed_nodes.add(node_id)
|
||||
self._cleanup_ancestors(node_id)
|
||||
|
||||
def get(self, node_id):
|
||||
"""
|
||||
Retrieve the cached value for a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to retrieve.
|
||||
|
||||
Returns:
|
||||
The cached value for the node.
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the parent node.
|
||||
children_ids: List of child node IDs to associate with the parent node.
|
||||
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
return subcache
|
||||
|
||||
def _cleanup_ancestors(self, node_id):
|
||||
"""
|
||||
Check if ancestors of a node can be removed from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose ancestors are to be checked.
|
||||
"""
|
||||
for ancestor_id in self.ancestors.get(node_id, []):
|
||||
if ancestor_id in self.executed_nodes:
|
||||
# Remove ancestor if all its descendants have been executed
|
||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||
self._remove_node(ancestor_id)
|
||||
|
||||
def _remove_node(self, node_id):
|
||||
"""
|
||||
Remove a node from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to remove.
|
||||
"""
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
del self.cache[cache_key]
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
del self.subcaches[subcache_key]
|
||||
|
||||
def clean_unused(self):
|
||||
"""
|
||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
"""
|
||||
Dump the cache and dependency graph for debugging.
|
||||
|
||||
Returns:
|
||||
A list containing the cache state and dependency graph.
|
||||
"""
|
||||
result = super().recursive_debug_dump()
|
||||
result.append({
|
||||
"descendants": self.descendants,
|
||||
"ancestors": self.ancestors,
|
||||
"executed_nodes": list(self.executed_nodes),
|
||||
})
|
||||
return result
|
||||
|
31
execution.py
31
execution.py
@ -15,7 +15,7 @@ import nodes
|
||||
import comfy.model_management
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy_execution.validation import validate_node_input
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@ -60,26 +60,32 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, lru_size=None):
|
||||
if lru_size is None or lru_size == 0:
|
||||
def __init__(self, lru_size=None, cache_none=False):
|
||||
if cache_none:
|
||||
self.init_dependency_aware_cache()
|
||||
elif lru_size is None or lru_size == 0:
|
||||
self.init_classic_cache()
|
||||
else:
|
||||
self.init_lru_cache(lru_size)
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
|
||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||
# blowing away the cache every time
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# only hold cached items while the decendents have not executed
|
||||
def init_dependency_aware_cache(self):
|
||||
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.objects = DependencyAwareCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
@ -414,13 +420,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server, lru_size=None):
|
||||
def __init__(self, server, lru_size=None, cache_none=False):
|
||||
self.lru_size = lru_size
|
||||
self.cache_none = cache_none
|
||||
self.server = server
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.caches = CacheSet(self.lru_size)
|
||||
self.caches = CacheSet(self.lru_size, self.cache_none)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
|
||||
|
2
main.py
2
main.py
@ -156,7 +156,7 @@ def cuda_malloc_warning():
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
|
||||
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none)
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
|
@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.14.6
|
||||
comfyui-frontend-package==1.15.13
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
response: web.Response = await handler(request)
|
||||
if request.path.endswith('.js') or request.path.endswith('.css'):
|
||||
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
|
||||
response.headers.setdefault('Cache-Control', 'no-cache')
|
||||
return response
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user