mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Update optimized_attention_for_device function for new functions that
support masked attention.
This commit is contained in:
parent
aaa9017302
commit
c6951548cf
@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module):
|
||||
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
|
@ -333,7 +333,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
optimized_attention_masked = attention_basic
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
print("Using xformers cross attention")
|
||||
@ -349,15 +348,15 @@ else:
|
||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
if model_management.pytorch_attention_enabled():
|
||||
optimized_attention_masked = attention_pytorch
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
if small_input and model_management.pytorch_attention_enabled():
|
||||
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
|
||||
|
||||
if device == torch.device("cpu"):
|
||||
return attention_sub_quad
|
||||
|
||||
def optimized_attention_for_device(device, mask=False):
|
||||
if device == torch.device("cpu"): #TODO
|
||||
if model_management.pytorch_attention_enabled():
|
||||
return attention_pytorch
|
||||
else:
|
||||
return attention_basic
|
||||
if mask:
|
||||
return optimized_attention_masked
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user