Make VAE memory estimation take dtype into account.

This commit is contained in:
comfyanonymous 2023-11-22 18:16:02 -05:00
parent 32447f0c39
commit 410bf07771

View File

@ -155,8 +155,8 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd) sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7 self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None: if config is None:
if "taesd_decoder.1.weight" in sd: if "taesd_decoder.1.weight" in sd:
@ -213,7 +213,7 @@ class VAE:
def decode(self, samples_in): def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
try: try:
memory_used = self.memory_used_decode(samples_in.shape) memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device) model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
@ -241,7 +241,7 @@ class VAE:
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
try: try:
memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device) model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)