When cached_hook_patches contain weights for hooks, only use hook_backup for unused keys

This commit is contained in:
Jedrzej Kosinski 2025-03-04 02:38:31 -06:00
parent 7c7c70c400
commit 78fbd61caf

View File

@ -1089,7 +1089,6 @@ class ModelPatcher:
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
self.unpatch_hooks()
if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None
@ -1100,12 +1099,16 @@ class ModelPatcher:
# if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
model_sd_keys_set = set(model_sd_keys)
for key in cached_weights:
if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
model_sd_keys_set.remove(key)
self.unpatch_hooks(model_sd_keys_set)
else:
self.unpatch_hooks()
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None
if len(relevant_patches) > 0:
@ -1116,6 +1119,8 @@ class ModelPatcher:
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
else:
self.unpatch_hooks()
self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@ -1172,17 +1177,23 @@ class ModelPatcher:
del out_weight
del weight
def unpatch_hooks(self) -> None:
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
self.current_hooks = None
return
keys = list(self.hook_backup.keys())
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
if whitelist_keys_set:
for k in keys:
if k in whitelist_keys_set:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.pop(k)
else:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear()
self.current_hooks = None
self.hook_backup.clear()
self.current_hooks = None
def clean_hooks(self):
self.unpatch_hooks()