Load weights that can't be lowvramed to target device.

This commit is contained in:
comfyanonymous 2023-12-28 21:41:10 -05:00
parent a8baa40d85
commit e1e322cf69

View File

@ -259,6 +259,14 @@ print("VAE dtype:", VAE_DTYPE)
current_loaded_models = []
def module_size(module):
module_mem = 0
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel:
def __init__(self, model):
self.model = model
@ -296,14 +304,14 @@ class LoadedModel:
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = 0
sd = m.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
module_mem = module_size(m)
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
m.to(self.device)
mem_counter += module_size(m)
print("lowvram: loaded module regularly", m)
self.model_accelerated = True