From 410bf0777197c7005fe13aa4f6717d6dc63e2b22 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Nov 2023 18:16:02 -0500 Subject: [PATCH] Make VAE memory estimation take dtype into account. --- comfy/sd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c006a036..a8df3bdd 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -155,8 +155,8 @@ class VAE: if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) - self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking - self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7 + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) if config is None: if "taesd_decoder.1.weight" in sd: @@ -213,7 +213,7 @@ class VAE: def decode(self, samples_in): self.first_stage_model = self.first_stage_model.to(self.device) try: - memory_used = self.memory_used_decode(samples_in.shape) + memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) @@ -241,7 +241,7 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: - memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used)