mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Make hook_scope functional for TransformerOptionsHook
This commit is contained in:
parent
6463c39ce0
commit
f48f90e471
@ -86,9 +86,9 @@ class Hook:
|
||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||
self.hook_id = hook_id
|
||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||
self.hook_scope = hook_scope
|
||||
self.custom_should_register = default_should_register
|
||||
self.auto_apply_to_nonpositive = False
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
@ -107,6 +107,7 @@ class Hook:
|
||||
c.hook_ref = self.hook_ref
|
||||
c.hook_id = self.hook_id
|
||||
c.hook_keyframe = self.hook_keyframe
|
||||
c.hook_scope = self.hook_scope
|
||||
c.custom_should_register = self.custom_should_register
|
||||
# TODO: make this do something
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
@ -118,12 +119,6 @@ class Hook:
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def __eq__(self, other: Hook):
|
||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||
|
||||
@ -143,6 +138,7 @@ class WeightHook(Hook):
|
||||
self.need_weight_init = True
|
||||
self._strength_model = strength_model
|
||||
self._strength_clip = strength_clip
|
||||
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
|
||||
|
||||
@property
|
||||
def strength_model(self):
|
||||
@ -190,9 +186,11 @@ class WeightHook(Hook):
|
||||
return c
|
||||
|
||||
class ObjectPatchHook(Hook):
|
||||
def __init__(self, object_patches: dict[str]=None):
|
||||
def __init__(self, object_patches: dict[str]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||
self.object_patches = object_patches
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self):
|
||||
c: ObjectPatchHook = super().clone()
|
||||
@ -216,14 +214,11 @@ class AddModelsHook(Hook):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.models = models
|
||||
self.key = key
|
||||
self.append_when_same = True
|
||||
'''Curently does nothing.'''
|
||||
|
||||
def clone(self):
|
||||
c: AddModelsHook = super().clone()
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.key = self.key
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook):
|
||||
'''
|
||||
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||
'''
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||
self.transformers_dict = wrappers_dict
|
||||
self.transformers_dict = transformers_dict
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self):
|
||||
c: TransformerOptionsHook = super().clone()
|
||||
@ -254,8 +251,9 @@ class TransformerOptionsHook(Hook):
|
||||
"to_load_options": self.transformers_dict}
|
||||
else:
|
||||
add_model_options = {"to_load_options": self.transformers_dict}
|
||||
# only register if will not be included in AllConditioning to avoid double loading
|
||||
registered.add(self)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.add(self)
|
||||
return True
|
||||
|
||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
@ -265,10 +263,12 @@ WrapperHook = TransformerOptionsHook
|
||||
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
|
||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.Injections)
|
||||
self.key = key
|
||||
self.injections = injections
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self):
|
||||
c: SetInjectionsHook = super().clone()
|
||||
@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||
sorted_list.extend(object_list)
|
||||
return sorted_list
|
||||
|
||||
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
||||
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
||||
if hooks is None or model.is_clip:
|
||||
return {}
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
||||
hook: TransformerOptionsHook
|
||||
hook.on_apply_hooks(model, transformer_options)
|
||||
return transformer_options
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
|
@ -1010,11 +1010,11 @@ class ModelPatcher:
|
||||
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
||||
# TODO: return transformer_options dict with any additions from hooks
|
||||
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
||||
return {}
|
||||
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||
self.patch_hooks(hooks=hooks)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
||||
callback(self, hooks)
|
||||
return {}
|
||||
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||
|
||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||
with self.use_ejected():
|
||||
|
Loading…
Reference in New Issue
Block a user