Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time)

This commit is contained in:
Jedrzej Kosinski 2025-01-06 01:03:59 -06:00
parent 4446c86052
commit 03a97b604a
2 changed files with 31 additions and 12 deletions

View File

@ -317,6 +317,18 @@ class HookGroup:
def contains(self, hook: Hook): def contains(self, hook: Hook):
return hook in self.hooks return hook in self.hooks
def is_subset_of(self, other: HookGroup):
self_hooks = set(self.hooks)
other_hooks = set(other.hooks)
return self_hooks.issubset(other_hooks)
def new_with_common_hooks(self, other: HookGroup):
c = HookGroup()
for hook in self.hooks:
if other.contains(hook):
c.add(hook.clone())
return c
def clone(self): def clone(self):
c = HookGroup() c = HookGroup()
for hook in self.hooks: for hook in self.hooks:
@ -668,24 +680,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
else: else:
c_dict[hooks_key] = cache[hooks_tuple] c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
c = [] c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} if cache is None:
cache = {}
for t in conditioning: for t in conditioning:
n = [t[0], t[1].copy()] n = [t[0], t[1].copy()]
for k in values: for k in values:
if append_hooks and k == 'hooks': if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache) _combine_hooks_from_values(n[1], values, cache)
else: else:
n[1][k] = values[k] n[1][k] = values[k]
c.append(n) c.append(n)
return c return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
if hooks is None: if hooks is None:
return cond return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None: if timestep_range is None:
@ -720,9 +734,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
def set_conds_props(conds: list, strength: float, set_cond_area: str, def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = [] final_conds = []
cache = {}
for c in conds: for c in conds:
# first, apply lora_hook to conditioning, if provided # first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to conditioning # next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present # apply timesteps, if present
@ -734,9 +749,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = [] combined_conds = []
cache = {}
for c, masked_c in zip(conds, new_conds): for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided # first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to new conditioning, if provided # next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present # apply timesteps, if present
@ -748,9 +764,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
def set_default_conds_and_combine(conds: list, new_conds: list, def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = [] combined_conds = []
cache = {}
for c, new_c in zip(conds, new_conds): for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided # first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
# next, add default_cond key to cond so that during sampling, it can be identified # next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True}) new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present # apply timesteps, if present

View File

@ -210,7 +210,7 @@ class ModelPatcher:
self.injections: dict[str, list[PatcherInjection]] = {} self.injections: dict[str, list[PatcherInjection]] = {}
self.hook_patches: dict[comfy.hooks._HookRef] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
self.current_hooks: Optional[comfy.hooks.HookGroup] = None self.current_hooks: Optional[comfy.hooks.HookGroup] = None
@ -282,7 +282,7 @@ class ModelPatcher:
n.injections[k] = i.copy() n.injections[k] = i.copy()
# hooks # hooks
n.hook_patches = create_hook_patches_clone(self.hook_patches) n.hook_patches = create_hook_patches_clone(self.hook_patches)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
for group in self.cached_hook_patches: for group in self.cached_hook_patches:
n.cached_hook_patches[group] = {} n.cached_hook_patches[group] = {}
for k in self.cached_hook_patches[group]: for k in self.cached_hook_patches[group]:
@ -912,9 +912,9 @@ class ModelPatcher:
callback(self, timestep) callback(self, timestep)
def restore_hook_patches(self): def restore_hook_patches(self):
if len(self.hook_patches_backup) > 0: if self.hook_patches_backup is not None:
self.hook_patches = self.hook_patches_backup self.hook_patches = self.hook_patches_backup
self.hook_patches_backup = {} self.hook_patches_backup = None
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode self.hook_mode = hook_mode
@ -950,6 +950,8 @@ class ModelPatcher:
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
if hook.hook_ref not in self.hook_patches: if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook) weight_hooks_to_register.append(hook)
else:
registered.add(hook)
if len(weight_hooks_to_register) > 0: if len(weight_hooks_to_register) > 0:
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)