From c8ce599a8f8ce15a05e5084f0d91b8153250d4e7 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 2 Mar 2023 15:24:51 -0500
Subject: [PATCH 1/2] Add a button to interrupt processing to the ui.

---
 execution.py       | 2 ++
 nodes.py           | 4 ++--
 webshit/index.html | 6 ++++++
 3 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/execution.py b/execution.py
index b0f4f952..43cab207 100644
--- a/execution.py
+++ b/execution.py
@@ -135,6 +135,8 @@ class PromptExecutor:
         self.server = server
 
     def execute(self, prompt, extra_data={}):
+        nodes.interrupt_processing(False)
+
         if "client_id" in extra_data:
             self.server.client_id = extra_data["client_id"]
         else:
diff --git a/nodes.py b/nodes.py
index fe24a6cd..02cb7e8e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -45,8 +45,8 @@ def filter_files_extensions(files, extensions):
 def before_node_execution():
     model_management.throw_exception_if_processing_interrupted()
 
-def interrupt_processing():
-    model_management.interrupt_current_processing()
+def interrupt_processing(value=True):
+    model_management.interrupt_current_processing(value)
 
 class CLIPTextEncode:
     @classmethod
diff --git a/webshit/index.html b/webshit/index.html
index 410fb744..a4f3667b 100644
--- a/webshit/index.html
+++ b/webshit/index.html
@@ -796,6 +796,11 @@ function setRunningNode(id) {
 (() => {
     function updateStatus(data) {
         document.getElementById("queuesize").innerHTML = "Queue size: " + (data ? data.exec_info.queue_remaining : "ERR");
+        if (data && data.exec_info.queue_remaining) {
+            document.getElementById("cancelcurrentjobbutton").hidden = false;
+        } else {
+            document.getElementById("cancelcurrentjobbutton").hidden = true;
+        }
     }
 
     //fix for colab and other things that don't support websockets.
@@ -1083,6 +1088,7 @@ function clearItems(type) {
+        <button id="cancelcurrentjobbutton" …>…</button>
         Queued:
         1
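
Patch 1 wires the new UI control to a cooperative interrupt: updateStatus() only reveals the element with id cancelcurrentjobbutton while exec_info.queue_remaining is non-zero, interrupt_processing(value=True) forwards a cancel request to model_management, before_node_execution() asks model_management to raise if a request is pending, and PromptExecutor.execute() now clears any stale request with nodes.interrupt_processing(False) before starting a run. model_management.py is not part of this diff, so the sketch below fills in a plausible flag-and-exception implementation purely for illustration; InterruptProcessingException, interrupt_processing_flag and execute_prompt are hypothetical names, and only the two model_management function names come from the patch.

# A minimal sketch of the cooperative interrupt the patch relies on.
# model_management.py is not shown in this diff, so the module-level flag and
# the exception type below are assumptions made for illustration only.

interrupt_processing_flag = False

class InterruptProcessingException(Exception):
    pass

def interrupt_current_processing(value=True):
    # True when the user cancels; PromptExecutor.execute() passes False
    # (via nodes.interrupt_processing(False)) to clear stale requests
    global interrupt_processing_flag
    interrupt_processing_flag = value

def throw_exception_if_processing_interrupted():
    # nodes.before_node_execution() calls this ahead of every node
    if interrupt_processing_flag:
        raise InterruptProcessingException()

def execute_prompt(node_functions):
    # mirrors the executor's flow: reset the flag, then check before each node
    interrupt_current_processing(False)
    try:
        for run_node in node_functions:
            throw_exception_if_processing_interrupted()
            run_node()
    except InterruptProcessingException:
        print("Prompt execution interrupted")

Clearing the flag at the top of execute() matters: a cancel request that arrives after one prompt has already finished would otherwise abort the next prompt as soon as it starts.
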
From 1a612e1c74ecb845350bbeab7554992e9f2c175c Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 2 Mar 2023 17:01:20 -0500
Subject: [PATCH 2/2] Add some pytorch scaled_dot_product_attention code for testing.

--use-pytorch-cross-attention to use it.
---
 comfy/ldm/modules/attention.py | 56 ++++++++++++++++++++++++++++++++--
 1 file changed, 53 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 5ee7d5ae..00a20782 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -442,14 +442,64 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
         return self.to_out(out)
 
+class CrossAttentionPytorch(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+
+        return self.to_out(out)
+
 import sys
-if XFORMERS_IS_AVAILBLE == False:
+if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
     if "--use-split-cross-attention" in sys.argv:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-        CrossAttention = CrossAttentionBirchSan
+        if "--use-pytorch-cross-attention" in sys.argv:
+            print("Using pytorch cross attention")
+            torch.backends.cuda.enable_math_sdp(False)
+            CrossAttention = CrossAttentionPytorch
+        else:
+            print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+            CrossAttention = CrossAttentionBirchSan
 else:
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
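
Patch 2 adds CrossAttentionPytorch, selected with --use-pytorch-cross-attention when xformers is unavailable or --disable-xformers is passed; torch.backends.cuda.enable_math_sdp(False) removes the math backend from PyTorch's SDPA dispatcher so the fused flash or memory-efficient kernels are preferred on CUDA (scaled_dot_product_attention requires PyTorch 2.0 or newer). The forward pass folds the heads into the batch dimension, reshaping (batch, tokens, heads * dim_head) to (batch * heads, tokens, dim_head), and hands q, k, v to torch.nn.functional.scaled_dot_product_attention, which computes softmax(q @ k^T / sqrt(dim_head)) @ v. The standalone check below is not part of the patch; the tensor sizes are arbitrary, and it only verifies that the fused call matches a plain reference implementation of the same math.

# Standalone sanity check (not from the patch): the head-folding reshape plus
# F.scaled_dot_product_attention gives the same result as explicit
# softmax(q @ k^T / sqrt(dim_head)) @ v attention. Sizes are arbitrary.
import math

import torch
import torch.nn.functional as F

b, n_q, n_kv, heads, dim_head = 2, 16, 24, 8, 64

q = torch.randn(b, n_q, heads * dim_head)
k = torch.randn(b, n_kv, heads * dim_head)
v = torch.randn(b, n_kv, heads * dim_head)

def split_heads(t):
    # (b, tokens, heads * dim_head) -> (b * heads, tokens, dim_head),
    # equivalent to the reshape/permute chain in CrossAttentionPytorch.forward
    return (t.reshape(b, t.shape[1], heads, dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b * heads, t.shape[1], dim_head))

qh, kh, vh = map(split_heads, (q, k, v))

# fused kernel, the same call the patch makes
out_sdpa = F.scaled_dot_product_attention(qh, kh, vh, attn_mask=None,
                                          dropout_p=0.0, is_causal=False)

# reference implementation of the same attention math
scores = qh @ kh.transpose(-2, -1) / math.sqrt(dim_head)
out_ref = scores.softmax(dim=-1) @ vh

print(torch.allclose(out_sdpa, out_ref, atol=1e-5))  # True, up to float tolerance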