diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 5ee7d5ae..00a20782 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -442,14 +442,64 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
         return self.to_out(out)
 
+class CrossAttentionPytorch(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+
+        return self.to_out(out)
+
 import sys
-if XFORMERS_IS_AVAILBLE == False:
+if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
     if "--use-split-cross-attention" in sys.argv:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-        CrossAttention = CrossAttentionBirchSan
+        if "--use-pytorch-cross-attention" in sys.argv:
+            print("Using pytorch cross attention")
+            torch.backends.cuda.enable_math_sdp(False)
+            CrossAttention = CrossAttentionPytorch
+        else:
+            print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+            CrossAttention = CrossAttentionBirchSan
 else:
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
diff --git a/execution.py b/execution.py
index b0f4f952..43cab207 100644
--- a/execution.py
+++ b/execution.py
@@ -135,6 +135,8 @@ class PromptExecutor:
         self.server = server
 
     def execute(self, prompt, extra_data={}):
+        nodes.interrupt_processing(False)
+
         if "client_id" in extra_data:
             self.server.client_id = extra_data["client_id"]
         else:
diff --git a/nodes.py b/nodes.py
index fe24a6cd..02cb7e8e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -45,8 +45,8 @@ def filter_files_extensions(files, extensions):
 def before_node_execution():
     model_management.throw_exception_if_processing_interrupted()
 
-def interrupt_processing():
-    model_management.interrupt_current_processing()
+def interrupt_processing(value=True):
+    model_management.interrupt_current_processing(value)
 
 class CLIPTextEncode:
     @classmethod
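
Aside, not part of the diff: a minimal standalone sketch (arbitrary example sizes, no ComfyUI imports) showing that the head-split reshape plus torch.nn.functional.scaled_dot_product_attention used by the new CrossAttentionPytorch path matches a plain softmax-attention reference, and how the heads are merged back afterwards.

# Standalone sketch, not part of the diff: checks that the
# (b, n, heads*dim_head) -> (b*heads, n, dim_head) head split followed by
# scaled_dot_product_attention, as in CrossAttentionPytorch above, equals
# plain softmax attention. Sizes below are arbitrary example values.
import torch
import torch.nn.functional as F

b, n, heads, dim_head = 2, 16, 8, 64
q = torch.randn(b, n, heads * dim_head)
k = torch.randn(b, n, heads * dim_head)
v = torch.randn(b, n, heads * dim_head)

def split_heads(t):
    # same reshape/permute chain as in the diff (minus the no-op unsqueeze)
    return (t.reshape(b, t.shape[1], heads, dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b * heads, t.shape[1], dim_head)
             .contiguous())

qh, kh, vh = map(split_heads, (q, k, v))

# fused path used by the new class
out_sdpa = F.scaled_dot_product_attention(qh, kh, vh, attn_mask=None,
                                          dropout_p=0.0, is_causal=False)

# explicit reference: softmax(q @ k^T / sqrt(dim_head)) @ v
attn = torch.softmax(qh @ kh.transpose(-2, -1) * dim_head ** -0.5, dim=-1)
out_ref = attn @ vh
print(torch.allclose(out_sdpa, out_ref, atol=1e-5))  # True

# merge heads back, as in the diff: (b*heads, n, dim_head) -> (b, n, heads*dim_head)
out = (out_sdpa.reshape(b, heads, n, dim_head)
               .permute(0, 2, 1, 3)
               .reshape(b, n, heads * dim_head))
print(out.shape)  # torch.Size([2, 16, 512])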