mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Merge remote-tracking branch 'origin' into frontendrefactor
This commit is contained in:
commit
65c432e6ee
@ -442,11 +442,61 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
class CrossAttentionPytorch(nn.Module):
|
||||||
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
context_dim = default(context_dim, query_dim)
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||||
|
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||||
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
|
def forward(self, x, context=None, mask=None):
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
b, _, _ = q.shape
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.unsqueeze(3)
|
||||||
|
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||||
|
.contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
raise NotImplementedError
|
||||||
|
out = (
|
||||||
|
out.unsqueeze(0)
|
||||||
|
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
if XFORMERS_IS_AVAILBLE == False:
|
if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
|
||||||
if "--use-split-cross-attention" in sys.argv:
|
if "--use-split-cross-attention" in sys.argv:
|
||||||
print("Using split optimization for cross attention")
|
print("Using split optimization for cross attention")
|
||||||
CrossAttention = CrossAttentionDoggettx
|
CrossAttention = CrossAttentionDoggettx
|
||||||
|
else:
|
||||||
|
if "--use-pytorch-cross-attention" in sys.argv:
|
||||||
|
print("Using pytorch cross attention")
|
||||||
|
torch.backends.cuda.enable_math_sdp(False)
|
||||||
|
CrossAttention = CrossAttentionPytorch
|
||||||
else:
|
else:
|
||||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||||
CrossAttention = CrossAttentionBirchSan
|
CrossAttention = CrossAttentionBirchSan
|
||||||
|
@ -135,6 +135,8 @@ class PromptExecutor:
|
|||||||
self.server = server
|
self.server = server
|
||||||
|
|
||||||
def execute(self, prompt, extra_data={}):
|
def execute(self, prompt, extra_data={}):
|
||||||
|
nodes.interrupt_processing(False)
|
||||||
|
|
||||||
if "client_id" in extra_data:
|
if "client_id" in extra_data:
|
||||||
self.server.client_id = extra_data["client_id"]
|
self.server.client_id = extra_data["client_id"]
|
||||||
else:
|
else:
|
||||||
|
4
nodes.py
4
nodes.py
@ -45,8 +45,8 @@ def filter_files_extensions(files, extensions):
|
|||||||
def before_node_execution():
|
def before_node_execution():
|
||||||
model_management.throw_exception_if_processing_interrupted()
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
def interrupt_processing():
|
def interrupt_processing(value=True):
|
||||||
model_management.interrupt_current_processing()
|
model_management.interrupt_current_processing(value)
|
||||||
|
|
||||||
class CLIPTextEncode:
|
class CLIPTextEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user