Make hook_scope functional for TransformerOptionsHook

This commit is contained in:
Jedrzej Kosinski 2025-01-06 02:23:04 -06:00
parent 6463c39ce0
commit f48f90e471
2 changed files with 28 additions and 17 deletions

View File

@ -86,9 +86,9 @@ class Hook:
self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.hook_scope = hook_scope
self.custom_should_register = default_should_register self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False self.auto_apply_to_nonpositive = False
self.hook_scope = hook_scope
@property @property
def strength(self): def strength(self):
@ -107,6 +107,7 @@ class Hook:
c.hook_ref = self.hook_ref c.hook_ref = self.hook_ref
c.hook_id = self.hook_id c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register c.custom_should_register = self.custom_should_register
# TODO: make this do something # TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
@ -118,12 +119,6 @@ class Hook:
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
pass
def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]):
pass
def __eq__(self, other: Hook): def __eq__(self, other: Hook):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
@ -143,6 +138,7 @@ class WeightHook(Hook):
self.need_weight_init = True self.need_weight_init = True
self._strength_model = strength_model self._strength_model = strength_model
self._strength_clip = strength_clip self._strength_clip = strength_clip
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
@property @property
def strength_model(self): def strength_model(self):
@ -190,9 +186,11 @@ class WeightHook(Hook):
return c return c
class ObjectPatchHook(Hook): class ObjectPatchHook(Hook):
def __init__(self, object_patches: dict[str]=None): def __init__(self, object_patches: dict[str]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.ObjectPatch) super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches = object_patches self.object_patches = object_patches
self.hook_scope = hook_scope
def clone(self): def clone(self):
c: ObjectPatchHook = super().clone() c: ObjectPatchHook = super().clone()
@ -216,14 +214,11 @@ class AddModelsHook(Hook):
super().__init__(hook_type=EnumHookType.AddModels) super().__init__(hook_type=EnumHookType.AddModels)
self.models = models self.models = models
self.key = key self.key = key
self.append_when_same = True
'''Curently does nothing.'''
def clone(self): def clone(self):
c: AddModelsHook = super().clone() c: AddModelsHook = super().clone()
c.models = self.models.copy() if self.models else self.models c.models = self.models.copy() if self.models else self.models
c.key = self.key c.key = self.key
c.append_when_same = self.append_when_same
return c return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook):
''' '''
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
''' '''
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.TransformerOptions) super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = wrappers_dict self.transformers_dict = transformers_dict
self.hook_scope = hook_scope
def clone(self): def clone(self):
c: TransformerOptionsHook = super().clone() c: TransformerOptionsHook = super().clone()
@ -254,8 +251,9 @@ class TransformerOptionsHook(Hook):
"to_load_options": self.transformers_dict} "to_load_options": self.transformers_dict}
else: else:
add_model_options = {"to_load_options": self.transformers_dict} add_model_options = {"to_load_options": self.transformers_dict}
# only register if will not be included in AllConditioning to avoid double loading
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.add(self)
return True return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
@ -265,10 +263,12 @@ WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
class SetInjectionsHook(Hook): class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None): def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.Injections) super().__init__(hook_type=EnumHookType.Injections)
self.key = key self.key = key
self.injections = injections self.injections = injections
self.hook_scope = hook_scope
def clone(self): def clone(self):
c: SetInjectionsHook = super().clone() c: SetInjectionsHook = super().clone()
@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
sorted_list.extend(object_list) sorted_list.extend(object_list)
return sorted_list return sorted_list
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
# if no hooks or is not a ModelPatcher for sampling, return empty dict
if hooks is None or model.is_clip:
return {}
if transformer_options is None:
transformer_options = {}
for hook in hooks.get_type(EnumHookType.TransformerOptions):
hook: TransformerOptionsHook
hook.on_apply_hooks(model, transformer_options)
return transformer_options
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup() hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)

View File

@ -1010,11 +1010,11 @@ class ModelPatcher:
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks # TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {} return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
self.patch_hooks(hooks=hooks) self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks) callback(self, hooks)
return {} return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
def patch_hooks(self, hooks: comfy.hooks.HookGroup): def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected(): with self.use_ejected():