ComfyUI (mirror of https://github.com/comfyanonymous/ComfyUI.git)

Try again with vae tiled decoding if regular fails because of OOM.

commit 3ed4a4e4e6 (parent aae9fe0cf9)
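In short, VAE.decode now wraps the regular full-image decode in a try/except on an out-of-memory exception and, if it trips, prints a warning and retries with the slower tiled decode instead of failing the whole prompt. A minimal, self-contained sketch of that fallback pattern (decode_full and decode_tiled below are illustrative stand-ins, not ComfyUI functions):

import torch

# torch.cuda.OutOfMemoryError only exists on newer PyTorch builds, so older
# ones fall back to catching Exception (mirrors the alias this commit adds).
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:
    OOM_EXCEPTION = Exception

def decode_full(latent):
    # stand-in for the regular, whole-image VAE decode
    raise OOM_EXCEPTION("simulated out-of-memory")

def decode_tiled(latent):
    # stand-in for the slower, memory-friendly tiled decode
    return latent

def decode(latent):
    try:
        return decode_full(latent)
    except OOM_EXCEPTION:
        print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
        return decode_tiled(latent)

print(decode(torch.zeros(1, 4, 64, 64)).shape)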
comfy/ldm/modules/attention.py

@@ -20,11 +20,6 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def exists(val):
     return val is not None
 
@@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
                     r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 if first_op_done == False:
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
comfy/ldm/modules/diffusionmodules/model.py

@@ -13,11 +13,6 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
                     r1[:, :, i:end] = torch.bmm(v, s2)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 steps *= 2
                 if steps > 128:
                     raise e
comfy/ldm/modules/sub_quadratic_attention.py

@@ -24,10 +24,7 @@ except ImportError:
 from torch import Tensor
 from typing import List
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
+import model_management
 
 def dynamic_slice(
     x: Tensor,
@@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except OOM_EXCEPTION:
+    except model_management.OOM_EXCEPTION:
         print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
comfy/model_management.py

@@ -31,6 +31,11 @@ try:
 except:
     pass
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
+
 if "--disable-xformers" in sys.argv:
     XFORMERS_IS_AVAILBLE = False
 else:
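The alias now lives in one place, comfy/model_management.py, and the attention and VAE code paths refer to model_management.OOM_EXCEPTION instead of each keeping a private copy. The try/except guards against PyTorch builds that do not yet expose torch.cuda.OutOfMemoryError; a more compact equivalent of the same idea, shown here only as a sketch:

import torch

# Same effect as the try/except block above: use the dedicated CUDA OOM
# exception when this PyTorch build provides it, otherwise catch everything.
OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)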
comfy/sd.py (30 changed lines)
@@ -383,12 +383,26 @@ class VAE:
         device = model_management.get_torch_device()
         self.device = device
 
-    def decode(self, samples):
+    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
+        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        output = torch.clamp((
+            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
+             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
+             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
+            / 3.0) / 2.0, min=0.0, max=1.0)
+        return output
+
+    def decode(self, samples_in):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        samples = samples.to(self.device)
-        pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
-        pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        try:
+            samples = samples_in.to(self.device)
+            pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
+            pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        except model_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            pixel_samples = self.decode_tiled_(samples_in)
+
         self.first_stage_model = self.first_stage_model.cpu()
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
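decode_tiled_ runs the tiled decoder three times with different tile aspect ratios (tile_x // 2 by tile_y * 2, tile_x * 2 by tile_y // 2, and tile_x by tile_y) and averages the results, which helps hide the seams any single tiling would leave; each pass hands individual latent tiles to the decoder through utils.tiled_scale with an 8x spatial upscale. A rough, self-contained sketch of the tiling idea follows; tiled_apply and fake_decode are illustrative stand-ins, not ComfyUI code, and the real utils.tiled_scale additionally overlaps tiles and blends the seams:

import torch

def tiled_apply(latent, fn, tile=64, upscale=8):
    # Split the latent into non-overlapping tiles, run fn on each tile,
    # and stitch the upscaled outputs back into one image tensor.
    b, c, h, w = latent.shape
    out = torch.zeros((b, 3, h * upscale, w * upscale))
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            patch = latent[:, :, y:y + tile, x:x + tile]
            out[:, :,
                y * upscale:(y + patch.shape[2]) * upscale,
                x * upscale:(x + patch.shape[3]) * upscale] = fn(patch)
    return out

# Toy "decoder" that just upsamples the first 3 latent channels 8x.
fake_decode = lambda a: a[:, :3].repeat_interleave(8, 2).repeat_interleave(8, 3)

latent = torch.randn(1, 4, 96, 64)
print(tiled_apply(latent, fake_decode).shape)  # torch.Size([1, 3, 768, 512])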
@@ -396,13 +410,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
-        output = torch.clamp((
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
-             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
-             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
-            / 3.0) / 2.0, min=0.0, max=1.0)
-
+        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         self.first_stage_model = self.first_stage_model.cpu()
         return output.movedim(1,-1)
 
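With this second hunk, the pre-existing decode_tiled entry point and the new OOM fallback inside decode both route through the same decode_tiled_ helper, so the two tiled code paths share one implementation and cannot drift apart.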