Mirror of https://github.com/comfyanonymous/ComfyUI.git
Memory check before inference to avoid VAE Decode exceeding available VRAM.

Check that free memory is not less than the expected usage before doing the actual decoding; if the check fails, switch directly to tiled VAE decoding. It seems PyTorch may keep the memory occupied until the model is destroyed once an OOM has occurred, so this commit tries to prevent the OOM from happening in the first place for VAE Decode. This addresses VAE Decode running with exceeded VRAM from #5737.
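The idea generalizes to any two-tier decode path: estimate the VRAM one batch will need, compare it to what is actually free, and go straight to the tiled path when the estimate does not fit. Below is a minimal sketch of that flow, not the real comfy/sd.py code: the OOMPredicted exception and the vae helpers it calls (predict_decode_memory, get_free_memory, decode, decode_tiled) are hypothetical stand-ins.

import logging

import torch


class OOMPredicted(Exception):
    """Raised when the upfront check predicts the regular decode will not fit."""


def decode_with_fallback(vae, latents):
    # Hypothetical wrapper illustrating the check-then-fallback idea described
    # in the commit message; the vae methods used here are assumptions.
    predicted_oom = False
    try:
        needed = vae.predict_decode_memory(latents.shape)  # expected VRAM for one batch
        free = vae.get_free_memory()                       # VRAM currently available
        logging.debug("free=%d bytes, predicted per-batch usage=%d bytes", free, needed)
        if free < needed:
            # Skip the attempt that is expected to fail: a real OOM can leave
            # memory tied up until the model is torn down.
            predicted_oom = True
            raise OOMPredicted()
        return vae.decode(latents)
    except (OOMPredicted, torch.cuda.OutOfMemoryError):
        if not predicted_oom:
            logging.warning("Regular VAE decode ran out of memory, retrying with tiled decode.")
        return vae.decode_tiled(latents)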
parent 3d802710e7
commit a3b9b3c1c3

comfy/sd.py (12 changed lines)
@@ -348,11 +348,19 @@ class VAE:
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
 
     def decode(self, samples_in):
+        predicted_oom = False
+        samples = None
+        out = None
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)
             free_memory = model_management.get_free_memory(self.device)
+            logging.debug(f"Free memory: {free_memory} bytes, predicted memory usage of one batch: {memory_used} bytes")
+            if free_memory < memory_used:
+                logging.warning("Warning: Out of memory is predicted for regular VAE decoding, switching directly to tiled VAE decoding.")
+                predicted_oom = True
+                raise model_management.OOM_EXCEPTION
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
 
@@ -363,6 +371,10 @@ class VAE:
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION as e:
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            samples = None
+            out = None
+            pixel_samples = None
+            if not predicted_oom:
+                logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             dims = samples_in.ndim - 2
             if dims == 1:
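In the except branch above, the handler also drops its references (samples, out, pixel_samples set to None) before falling back, so the partially built tensors are no longer held when the tiled retry runs. A generic sketch of that release-and-retry pattern, independent of this diff: the decode_fn and tiled_decode_fn callables are placeholders, and the explicit torch.cuda.empty_cache() call is not something this commit adds.

import torch


def retry_tiled(decode_fn, tiled_decode_fn, latents):
    out = None
    try:
        out = decode_fn(latents)
    except torch.cuda.OutOfMemoryError:
        # Drop the partial result so its VRAM is no longer referenced, then ask
        # the caching allocator to release unused blocks before the smaller
        # tiled attempt.
        out = None
        torch.cuda.empty_cache()
        out = tiled_decode_fn(latents)
    return out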