diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index e97badd04..23b047342 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -20,11 +20,6 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def exists(val):
     return val is not None
 
@@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
                         r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                         del s2
                     break
-                except OOM_EXCEPTION as e:
+                except model_management.OOM_EXCEPTION as e:
                     if first_op_done == False:
                         torch.cuda.empty_cache()
                         torch.cuda.ipc_collect()
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index d36fd59c1..94f5510b9 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -13,11 +13,6 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
                     r1[:, :, i:end] = torch.bmm(v, s2)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 steps *= 2
                 if steps > 128:
                     raise e
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index edbff74a2..f3c83f387 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -24,10 +24,7 @@ except ImportError:
 from torch import Tensor
 from typing import List
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
+import model_management
 
 def dynamic_slice(
     x: Tensor,
@@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except OOM_EXCEPTION:
+    except model_management.OOM_EXCEPTION:
         print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 5c4e97da3..809b19ea2 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -31,6 +31,11 @@ try:
 except:
     pass
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
+
 if "--disable-xformers" in sys.argv:
     XFORMERS_IS_AVAILBLE = False
 else:
diff --git a/comfy/sd.py b/comfy/sd.py
index 585419f7e..b344cbece 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -383,12 +383,26 @@ class VAE:
             device = model_management.get_torch_device()
         self.device = device
 
-    def decode(self, samples):
+    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
+        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        output = torch.clamp((
+            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
+            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
+             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
+            / 3.0) / 2.0, min=0.0, max=1.0)
+        return output
+
+    def decode(self, samples_in):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        samples = samples.to(self.device)
-        pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
-        pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        try:
+            samples = samples_in.to(self.device)
+            pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
+            pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        except model_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            pixel_samples = self.decode_tiled_(samples_in)
+
         self.first_stage_model = self.first_stage_model.cpu()
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
@@ -396,13 +410,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
-        output = torch.clamp((
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
-            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
-             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
-            / 3.0) / 2.0, min=0.0, max=1.0)
-
+        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
        self.first_stage_model = self.first_stage_model.cpu()
         return output.movedim(1,-1)
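
For reference, a minimal standalone sketch (not part of the diff) of the pattern the comfy/sd.py hunk introduces: OOM_EXCEPTION is defined once in comfy/model_management.py, and VAE decoding first attempts the regular full-resolution path, falling back to the tiled path when that exception is raised. The decode_full and decode_tiled callables below are hypothetical stand-ins for the bodies of VAE.decode and VAE.decode_tiled_, not ComfyUI APIs.

# Sketch of the OOM-fallback pattern; decode_full/decode_tiled are hypothetical stand-ins.
import torch

try:
    # torch.cuda.OutOfMemoryError exists on recent PyTorch builds; older ones
    # raise AttributeError here, so fall back to the generic Exception
    # (the diff uses a bare except for the same purpose).
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:
    OOM_EXCEPTION = Exception

def decode_with_fallback(latents, decode_full, decode_tiled):
    """Try a full VAE decode; on out-of-memory, retry with the tiled path."""
    try:
        return decode_full(latents)
    except OOM_EXCEPTION:
        print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
        return decode_tiled(latents)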