Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-11 02:15:17 +00:00)
Make VAE memory estimation take dtype into account.
This commit is contained in:
parent 32447f0c39
commit 410bf07771
@@ -155,8 +155,8 @@ class VAE:
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
 
-        self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking
-        self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7
+        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
+        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
 
         if config is None:
             if "taesd_decoder.1.weight" in sd:
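model_management.dtype_size(dtype) evidently returns the per-element size in bytes, so the estimate now scales with the precision the VAE actually runs in instead of baking a fixed x1.7 fudge factor into the constants. Notably, the new constants roughly reproduce the old numbers for a 16-bit VAE (1767 x 2 is about 2078 x 1.7, and 2178 x 2 is about 2562 x 1.7) and simply double them for fp32. A minimal sanity check, using a stand-in helper for dtype_size (an assumption for illustration, not the real model_management code):

    import torch

    def dtype_size(dtype):
        # Stand-in for model_management.dtype_size (assumed: bytes per element).
        return torch.tensor([], dtype=dtype).element_size()

    old_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7
    new_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * dtype_size(dtype)

    shape = (1, 3, 512, 512)                      # pixel-space input to the encoder
    print(old_encode(shape))                      # ~926 MB (implicitly assumed a 16-bit VAE)
    print(new_encode(shape, torch.float16))       # ~926 MB, matches the old estimate
    print(new_encode(shape, torch.float32))       # ~1853 MB, doubled for a full-precision VAE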
@@ -213,7 +213,7 @@ class VAE:
     def decode(self, samples_in):
         self.first_stage_model = self.first_stage_model.to(self.device)
         try:
-            memory_used = self.memory_used_decode(samples_in.shape)
+            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
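The estimate feeds directly into how many latents are decoded per pass: free_memory is queried after the reservation, and batch_number is how many samples should fit in what remains. A simplified sketch of that pattern (an assumption about the surrounding loop, not the exact ComfyUI code; error handling and fallbacks omitted):

    def decode_in_batches(vae, samples_in, free_memory):
        # Reserve memory according to the dtype-aware estimate, then split the batch.
        memory_used = vae.memory_used_decode(samples_in.shape, vae.vae_dtype)
        batch_number = max(1, int(free_memory / memory_used))  # always decode at least one sample
        outputs = []
        for x in range(0, samples_in.shape[0], batch_number):
            batch = samples_in[x:x + batch_number].to(vae.vae_dtype).to(vae.device)
            outputs.append(vae.first_stage_model.decode(batch))
        return outputs

With the dtype factored in, a fp16 VAE decoding 512x512 images reserves about 2178 x 64 x 64 x 64 x 2, i.e. roughly 1.1 GB per sample (the x64 presumably maps the 8x-downsampled latent area back to pixel area), so around 7 samples fit in 8 GiB of free memory versus 3 if the same VAE ran in fp32.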
@@ -241,7 +241,7 @@ class VAE:
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
-            memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
             model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
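The encode path follows the same pattern, and the practical effect of passing self.vae_dtype shows up in batch_number: a fp32 VAE now reserves twice as much memory per image as a fp16 one, so fewer images are encoded per pass. A rough illustration with assumed numbers (8 GiB free, 1024x1024 inputs; illustrative values, not measurements):

    free_memory = 8 * 1024**3                    # pretend 8 GiB are free on the device
    shape = (16, 3, 1024, 1024)                  # a batch of sixteen 1024x1024 images

    for bytes_per_element in (2, 4):             # what dtype_size would return for fp16 / fp32
        memory_used = (1767 * shape[2] * shape[3]) * bytes_per_element
        batch_number = int(free_memory / memory_used)
        print(bytes_per_element, batch_number)   # 2 -> 2 images per pass, 4 -> 1 image per pass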