pytorch xpu should be flash or mem efficient attention?

comfyanonymous 2024-06-04 17:44:14 -04:00
parent 20447e9ec9
commit b1fd26fe9e


@@ -693,6 +693,8 @@ def pytorch_attention_flash_attention():
         #TODO: more reliable way of checking for flash attention?
         if is_nvidia(): #pytorch flash attention only works on Nvidia
             return True
+        if is_intel_xpu():
+            return True
     return False
 
 def force_upcast_attention_dtype():
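
A check like this typically gates whether attention runs through PyTorch's fused scaled_dot_product_attention kernel. The sketch below illustrates that pattern under stated assumptions; it is not the repository's code, and is_nvidia()/is_intel_xpu() here are stand-ins for the repo's own device helpers:

```python
import torch
import torch.nn.functional as F

def is_nvidia():
    # Stand-in for the repo's helper: true when a CUDA GPU is available.
    return torch.cuda.is_available()

def is_intel_xpu():
    # Stand-in for the repo's helper: true when torch has XPU support
    # and an Intel XPU device is present.
    return hasattr(torch, "xpu") and torch.xpu.is_available()

def pytorch_attention_flash_attention():
    # Mirrors the patched check: Nvidia or Intel XPU can take the fused path.
    return is_nvidia() or is_intel_xpu()

def attention(q, k, v):
    if pytorch_attention_flash_attention():
        # PyTorch's fused SDPA dispatches to a flash or memory-efficient
        # kernel when the backend provides one.
        return F.scaled_dot_product_attention(q, k, v)
    # Fallback: plain unfused attention.
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    return attn @ v

if __name__ == "__main__":
    q = k = v = torch.randn(1, 4, 16, 32)
    print(attention(q, k, v).shape)  # torch.Size([1, 4, 16, 32])
```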