Same OOM exception fallback, but for the other places where it's used.

comfyanonymous 2023-02-09 12:43:29 -05:00
parent df40d4f3bf
commit 773cdabfce
2 changed files with 11 additions and 2 deletions

View File

@@ -20,6 +20,11 @@ except:
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
+
 def exists(val):
     return val is not None
@@ -316,7 +321,7 @@ class CrossAttentionDoggettx(nn.Module):
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                 del s2
                 break
-            except torch.cuda.OutOfMemoryError as e:
+            except OOM_EXCEPTION as e:
                 if first_op_done == False:
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
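
Note on the pattern both files now share: torch.cuda.OutOfMemoryError only exists in newer PyTorch builds, so referencing it directly raises AttributeError on older ones; the guarded assignment catches that and falls back to the generic Exception (on those builds a CUDA OOM surfaces as a plain RuntimeError). A minimal standalone sketch of the shim; narrowing the bare except to AttributeError is an editorial choice, not what the commit does:

import torch

# Prefer the dedicated CUDA OOM error type when the running PyTorch
# provides it; otherwise fall back to the generic Exception.
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:  # the commit itself uses a bare `except:`
    OOM_EXCEPTION = Exception

# Usage mirrors the patched call sites:
#     try:
#         ...  # memory-hungry attention op
#     except OOM_EXCEPTION:
#         ...  # free cache or shrink slices, then retry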

View File

@@ -16,6 +16,10 @@ except:
     XFORMERS_IS_AVAILBLE = False
     print("No module 'xformers'. Proceeding without it.")
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """
@@ -229,7 +233,7 @@ class AttnBlock(nn.Module):
                 r1[:, :, i:end] = torch.bmm(v, s2)
                 del s2
                 break
-            except torch.cuda.OutOfMemoryError as e:
+            except OOM_EXCEPTION as e:
                 if first_op_done == False:
                     steps *= 2
                     if steps > 128:
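
Both patched except clauses sit inside the same slice-and-retry loop: attempt the attention matmul in slices, and on OOM double the number of slices (capped at 128 here) and start over. A hypothetical reconstruction of that loop for the torch.bmm case above; only the lines shown in the hunks come from the commit, while the surrounding structure and the sliced_bmm name are assumptions for illustration:

import torch

try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:
    OOM_EXCEPTION = Exception

def sliced_bmm(v, w, max_steps=128):
    # v: (b, c, m), w: (b, m, n) -> r1: (b, c, n), computed in slices
    # along the last dimension so a failed attempt can retry with
    # smaller slices instead of crashing on CUDA OOM.
    b, c, n = v.shape[0], v.shape[1], w.shape[2]
    r1 = torch.zeros(b, c, n, device=v.device, dtype=v.dtype)
    steps = 1
    first_op_done = False
    while True:
        try:
            slice_size = max(1, n // steps)
            for i in range(0, n, slice_size):
                end = min(i + slice_size, n)
                s2 = w[:, :, i:end]
                r1[:, :, i:end] = torch.bmm(v, s2)
                first_op_done = True
                del s2
            break
        except OOM_EXCEPTION as e:
            # If not even the first slice succeeded, assume the slices
            # are too large: free cached memory, double the slice count,
            # and retry. Otherwise something else went wrong; re-raise.
            if first_op_done == False:
                torch.cuda.empty_cache()
                steps *= 2
                if steps > max_steps:
                    raise e
            else:
                raise e
    return r1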