Use basic attention implementation for small inputs on old pytorch.

comfyanonymous 2024-01-09 13:46:52 -05:00
parent b3b5ddb07a
commit 6a7bc35db8

@@ -351,8 +351,11 @@ else:
     optimized_attention_masked = optimized_attention
 
 def optimized_attention_for_device(device, mask=False, small_input=False):
-    if small_input and model_management.pytorch_attention_enabled():
-        return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
+    if small_input:
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
+        else:
+            return attention_basic
 
     if device == torch.device("cpu"):
         return attention_sub_quad
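
For context, the fallback this commit adds (attention_basic) is the plain, unfused scaled dot-product attention path, used for small inputs when the fused PyTorch attention kernel is unavailable on older PyTorch builds. Below is a minimal sketch of that kind of implementation; the function name basic_attention, the tensor shapes, and the mask handling are illustrative assumptions, not code from this repository.

import torch

# Sketch of an unfused ("basic") scaled dot-product attention of the kind the
# attention_basic fallback provides when PyTorch's fused attention is not enabled.
# Name, shapes, and mask handling are assumptions for illustration only.
def basic_attention(q, k, v, heads, mask=None):
    b, n, _ = q.shape
    dim_head = q.shape[-1] // heads

    # (b, n, heads * dim_head) -> (b * heads, n, dim_head)
    q, k, v = (
        t.reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        for t in (q, k, v)
    )

    # Attention scores, scaled by 1 / sqrt(dim_head)
    scores = torch.bmm(q, k.transpose(1, 2)) * (dim_head ** -0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    probs = scores.softmax(dim=-1)

    out = torch.bmm(probs, v)
    # (b * heads, n, dim_head) -> (b, n, heads * dim_head)
    return out.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)

# Usage sketch (shapes are assumptions): batch of 2, 77 tokens, 8 heads of 64 dims.
q = k = v = torch.randn(2, 77, 8 * 64)
out = basic_attention(q, k, v, heads=8)  # -> (2, 77, 512)

This unfused path is slower than the fused kernel but has no version requirements, which is why the new branch only returns attention_pytorch when pytorch_attention_enabled() reports the fused implementation is usable and falls back to attention_basic otherwise.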