diff --git a/comfy/hooks.py b/comfy/hooks.py index dde3e8bcb..cc9f6cd54 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -86,9 +86,9 @@ class Hook: self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + self.hook_scope = hook_scope self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False - self.hook_scope = hook_scope @property def strength(self): @@ -107,6 +107,7 @@ class Hook: c.hook_ref = self.hook_ref c.hook_id = self.hook_id c.hook_keyframe = self.hook_keyframe + c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register # TODO: make this do something c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive @@ -118,12 +119,6 @@ class Hook: def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - - def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref @@ -143,6 +138,7 @@ class WeightHook(Hook): self.need_weight_init = True self._strength_model = strength_model self._strength_clip = strength_clip + self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs @property def strength_model(self): @@ -190,9 +186,11 @@ class WeightHook(Hook): return c class ObjectPatchHook(Hook): - def __init__(self, object_patches: dict[str]=None): + def __init__(self, object_patches: dict[str]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.ObjectPatch) self.object_patches = object_patches + self.hook_scope = hook_scope def clone(self): c: ObjectPatchHook = super().clone() @@ -216,14 +214,11 @@ class AddModelsHook(Hook): super().__init__(hook_type=EnumHookType.AddModels) self.models = models self.key = key - self.append_when_same = True - '''Curently does nothing.''' def clone(self): c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key - c.append_when_same = self.append_when_same return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook): ''' Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' - def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): + def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.TransformerOptions) - self.transformers_dict = wrappers_dict + self.transformers_dict = transformers_dict + self.hook_scope = hook_scope def clone(self): c: TransformerOptionsHook = super().clone() @@ -254,8 +251,9 @@ class TransformerOptionsHook(Hook): "to_load_options": self.transformers_dict} else: add_model_options = {"to_load_options": self.transformers_dict} + # only register if will not be included in AllConditioning to avoid double loading + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -265,10 +263,12 @@ WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + def __init__(self, key: str=None, injections: list[PatcherInjection]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections + self.hook_scope = hook_scope def clone(self): c: SetInjectionsHook = super().clone() @@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list: sorted_list.extend(object_list) return sorted_list +def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None): + # if no hooks or is not a ModelPatcher for sampling, return empty dict + if hooks is None or model.is_clip: + return {} + if transformer_options is None: + transformer_options = {} + for hook in hooks.get_type(EnumHookType.TransformerOptions): + hook: TransformerOptionsHook + hook.on_apply_hooks(model, transformer_options) + return transformer_options + def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): hook_group = HookGroup() hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57a843b8f..51a62e048 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1010,11 +1010,11 @@ class ModelPatcher: def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): # TODO: return transformer_options dict with any additions from hooks if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) self.patch_hooks(hooks=hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): callback(self, hooks) - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) def patch_hooks(self, hooks: comfy.hooks.HookGroup): with self.use_ejected():