Try again with VAE tiled decoding if regular decoding fails because of OOM.

comfyanonymous 2023-03-22 14:49:00 -04:00
parent aae9fe0cf9
commit 3ed4a4e4e6
5 changed files with 28 additions and 28 deletions
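The whole change reduces to one pattern: attempt the regular full-tensor VAE decode, and if CUDA raises an out-of-memory error, retry with the tiled decoder. A minimal standalone sketch of that pattern, where full_decode and tiled_decode are hypothetical stand-ins rather than this repository's API:

import torch

# Older torch builds have no torch.cuda.OutOfMemoryError; fall back to the
# broad Exception, mirroring the shim this commit moves into model_management.
OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)

def decode_with_fallback(latents, full_decode, tiled_decode):
    # Fast path first; tiled decoding is slower but needs far less VRAM.
    try:
        return full_decode(latents)
    except OOM_EXCEPTION:
        print("Warning: out of memory during regular VAE decode, retrying tiled.")
        return tiled_decode(latents)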

View File

@@ -20,11 +20,6 @@ if model_management.xformers_enabled():
 import os

 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def exists(val):
     return val is not None
@@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
                     r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 if first_op_done == False:
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
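For context, the except branch above runs inside a retry loop: after an OOM it releases cached CUDA memory so the next pass over the sliced einsum starts with a cleaner slate. A simplified sketch of that recovery strategy, with compute_chunked as a hypothetical stand-in for one pass of the slicing loop:

import torch

OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)

def attention_with_cache_retry(compute_chunked, max_retries=3):
    for _ in range(max_retries):
        try:
            return compute_chunked()
        except OOM_EXCEPTION:
            # Return cached blocks to the driver and collect inter-process
            # handles before trying again, as the diff above does.
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    return compute_chunked()  # final attempt; let any OOM propagate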

View File

@ -13,11 +13,6 @@ if model_management.xformers_enabled():
import xformers import xformers
import xformers.ops import xformers.ops
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim):
""" """
This matches the implementation in Denoising Diffusion Probabilistic Models: This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
                     r1[:, :, i:end] = torch.bmm(v, s2)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 steps *= 2
                 if steps > 128:
                     raise e
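The AttnBlock path recovers differently: instead of clearing caches it doubles steps (halving the slice size) and tries again, giving up once the work is split more than 128 ways. Roughly, with run_sliced(steps) as a hypothetical stand-in for the sliced torch.bmm loop:

import torch

OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)

def bmm_with_adaptive_slicing(run_sliced):
    steps = 1
    while True:
        try:
            return run_sliced(steps)
        except OOM_EXCEPTION as e:
            steps *= 2  # smaller slices -> smaller peak allocation
            if steps > 128:
                raise e  # splitting further will not help; surface the OOM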

View File

@@ -24,10 +24,7 @@ except ImportError:
 from torch import Tensor
 from typing import List

-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
+import model_management

 def dynamic_slice(
     x: Tensor,
@@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except OOM_EXCEPTION:
+    except model_management.OOM_EXCEPTION:
         print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
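The fallback branch here is a numerically stable softmax computed in place on attn_scores, avoiding the extra full-size allocation that softmax(dim=-1) makes. Its standalone equivalent is roughly:

import torch

def softmax_inplace(scores: torch.Tensor) -> torch.Tensor:
    scores -= scores.max(dim=-1, keepdim=True).values  # shift by row max for stability
    torch.exp(scores, out=scores)                      # exponentiate in place
    scores /= scores.sum(dim=-1, keepdim=True)         # rows now sum to 1
    return scores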

View File

@@ -31,6 +31,11 @@ try:
 except:
     pass

+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
+
 if "--disable-xformers" in sys.argv:
     XFORMERS_IS_AVAILBLE = False
 else:
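This is the destination of the shim deleted from the three files above: torch.cuda.OutOfMemoryError only exists in newer PyTorch releases, so older builds fall back to catching the broad Exception. A compact equivalent of the added block:

import torch

# Same effect as the try/except above: use the precise OOM class when the
# installed torch provides it, otherwise catch everything.
OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)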

View File

@@ -383,12 +383,26 @@ class VAE:
         device = model_management.get_torch_device()
         self.device = device

-    def decode(self, samples):
+    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
+        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        output = torch.clamp((
+            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
+            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
+            utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
+            / 3.0) / 2.0, min=0.0, max=1.0)
+        return output
+
+    def decode(self, samples_in):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        samples = samples.to(self.device)
-        pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
-        pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        try:
+            samples = samples_in.to(self.device)
+            pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
+            pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        except model_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            pixel_samples = self.decode_tiled_(samples_in)
         self.first_stage_model = self.first_stage_model.cpu()
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
@@ -396,13 +410,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
-        output = torch.clamp((
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
-            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
-            utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
-            / 3.0) / 2.0, min=0.0, max=1.0)
+        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         self.first_stage_model = self.first_stage_model.cpu()
         return output.movedim(1,-1)
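decode_tiled_ runs utils.tiled_scale three times with different tile layouts (tall, wide, square) and averages the results, so the seams of any single layout land in different places and are smoothed out. A toy illustration of that seam-averaging idea, where tiled_apply is a crude shape-preserving stand-in and the real utils.tiled_scale also feather-blends the overlap and upscales 8x:

import torch

def tiled_apply(x, fn, tile_x, tile_y, overlap):
    # Run fn over overlapping 2-D tiles; later tiles simply overwrite the
    # overlap region (the real helper blends it instead).
    out = torch.zeros_like(x)
    for y in range(0, x.shape[-2], tile_y - overlap):
        for x0 in range(0, x.shape[-1], tile_x - overlap):
            ys = slice(y, min(y + tile_y, x.shape[-2]))
            xs = slice(x0, min(x0 + tile_x, x.shape[-1]))
            out[..., ys, xs] = fn(x[..., ys, xs])
    return out

def seam_averaged(samples, fn, tile_x=64, tile_y=64, overlap=16):
    # Three tilings with seams in different places, averaged.
    return (tiled_apply(samples, fn, tile_x // 2, tile_y * 2, overlap) +
            tiled_apply(samples, fn, tile_x * 2, tile_y // 2, overlap) +
            tiled_apply(samples, fn, tile_x, tile_y, overlap)) / 3.0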