mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-13 14:21:20 +00:00
Load weights that can't be lowvramed to target device.
This commit is contained in:
parent
a8baa40d85
commit
e1e322cf69
@ -259,6 +259,14 @@ print("VAE dtype:", VAE_DTYPE)
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nelement() * t.element_size()
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
@ -296,14 +304,14 @@ class LoadedModel:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
module_mem = 0
|
||||
sd = m.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nelement() * t.element_size()
|
||||
module_mem = module_size(m)
|
||||
if mem_counter + module_mem < lowvram_model_memory:
|
||||
m.to(self.device)
|
||||
mem_counter += module_mem
|
||||
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
|
||||
m.to(self.device)
|
||||
mem_counter += module_size(m)
|
||||
print("lowvram: loaded module regularly", m)
|
||||
|
||||
self.model_accelerated = True
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user