diff --git a/comfy/sd.py b/comfy/sd.py index 718dccd0..d898d019 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -589,11 +589,14 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: - batch_number = 1 + free_memory = model_management.get_free_memory(self.device) + batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + batch_number = max(1, batch_number) samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device) samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor + except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples)