diff --git a/comfy/hooks.py b/comfy/hooks.py index b62092cc..dde3e8bc 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -317,6 +317,18 @@ class HookGroup: def contains(self, hook: Hook): return hook in self.hooks + def is_subset_of(self, other: HookGroup): + self_hooks = set(self.hooks) + other_hooks = set(other.hooks) + return self_hooks.issubset(other_hooks) + + def new_with_common_hooks(self, other: HookGroup): + c = HookGroup() + for hook in self.hooks: + if other.contains(hook): + c.add(hook.clone()) + return c + def clone(self): c = HookGroup() for hook in self.hooks: @@ -668,24 +680,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H else: c_dict[hooks_key] = cache[hooks_tuple] -def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): +def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True, + cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): c = [] - hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} + if cache is None: + cache = {} for t in conditioning: n = [t[0], t[1].copy()] for k in values: if append_hooks and k == 'hooks': - _combine_hooks_from_values(n[1], values, hooks_combine_cache) + _combine_hooks_from_values(n[1], values, cache) else: n[1][k] = values[k] c.append(n) return c -def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): +def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): if hooks is None: return cond - return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) + return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache) def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): if timestep_range is None: @@ -720,9 +734,10 @@ def combine_with_new_conds(conds: list, new_conds: list): def set_conds_props(conds: list, strength: float, set_cond_area: str, mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): final_conds = [] + cache = {} for c in conds: # first, apply lora_hook to conditioning, if provided - c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) + c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to conditioning c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) # apply timesteps, if present @@ -734,9 +749,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str, def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, masked_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) + masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to new conditioning, if provided masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) # apply timesteps, if present @@ -748,9 +764,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1. def set_default_conds_and_combine(conds: list, new_conds: list, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, new_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) + new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache) # next, add default_cond key to cond so that during sampling, it can be identified new_c = conditioning_set_values(new_c, {'default': True}) # apply timesteps, if present diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2a551087..57a843b8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -210,7 +210,7 @@ class ModelPatcher: self.injections: dict[str, list[PatcherInjection]] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {} - self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} + self.hook_patches_backup: dict[comfy.hooks._HookRef] = None self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.current_hooks: Optional[comfy.hooks.HookGroup] = None @@ -282,7 +282,7 @@ class ModelPatcher: n.injections[k] = i.copy() # hooks n.hook_patches = create_hook_patches_clone(self.hook_patches) - n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) + n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup for group in self.cached_hook_patches: n.cached_hook_patches[group] = {} for k in self.cached_hook_patches[group]: @@ -912,9 +912,9 @@ class ModelPatcher: callback(self, timestep) def restore_hook_patches(self): - if len(self.hook_patches_backup) > 0: + if self.hook_patches_backup is not None: self.hook_patches = self.hook_patches_backup - self.hook_patches_backup = {} + self.hook_patches_backup = None def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode @@ -950,6 +950,8 @@ class ModelPatcher: for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) + else: + registered.add(hook) if len(weight_hooks_to_register) > 0: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)