Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-15 14:09:36 +00:00)
Make hook_scope functional for TransformerOptionsHook

parent 6463c39ce0
commit f48f90e471
comfy/hooks.py

@@ -86,9 +86,9 @@ class Hook:
         self.hook_ref = hook_ref if hook_ref else _HookRef()
         self.hook_id = hook_id
         self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
+        self.hook_scope = hook_scope
         self.custom_should_register = default_should_register
         self.auto_apply_to_nonpositive = False
-        self.hook_scope = hook_scope

     @property
     def strength(self):
@@ -107,6 +107,7 @@ class Hook:
         c.hook_ref = self.hook_ref
         c.hook_id = self.hook_id
         c.hook_keyframe = self.hook_keyframe
+        c.hook_scope = self.hook_scope
         c.custom_should_register = self.custom_should_register
         # TODO: make this do something
         c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
@@ -118,12 +119,6 @@ class Hook:
     def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")

-    def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
-        pass
-
-    def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]):
-        pass
-
     def __eq__(self, other: Hook):
         return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref

@@ -143,6 +138,7 @@ class WeightHook(Hook):
         self.need_weight_init = True
         self._strength_model = strength_model
         self._strength_clip = strength_clip
+        self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs

     @property
     def strength_model(self):
@@ -190,9 +186,11 @@ class WeightHook(Hook):
         return c

 class ObjectPatchHook(Hook):
-    def __init__(self, object_patches: dict[str]=None):
+    def __init__(self, object_patches: dict[str]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
         super().__init__(hook_type=EnumHookType.ObjectPatch)
         self.object_patches = object_patches
+        self.hook_scope = hook_scope

     def clone(self):
         c: ObjectPatchHook = super().clone()
@@ -216,14 +214,11 @@ class AddModelsHook(Hook):
         super().__init__(hook_type=EnumHookType.AddModels)
         self.models = models
         self.key = key
-        self.append_when_same = True
-        '''Curently does nothing.'''

     def clone(self):
         c: AddModelsHook = super().clone()
         c.models = self.models.copy() if self.models else self.models
         c.key = self.key
-        c.append_when_same = self.append_when_same
         return c

     def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook):
     '''
     Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
     '''
-    def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
+    def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
         super().__init__(hook_type=EnumHookType.TransformerOptions)
-        self.transformers_dict = wrappers_dict
+        self.transformers_dict = transformers_dict
+        self.hook_scope = hook_scope

     def clone(self):
         c: TransformerOptionsHook = super().clone()
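For context, a minimal construction sketch based on the new signature above; the wrapper dict contents and the variable name are illustrative assumptions, not part of this diff:

    from comfy.hooks import TransformerOptionsHook, EnumHookScope

    # hook_scope is now honored for this hook type; HookedOnly and
    # AllConditioning are the scope values referenced elsewhere in this diff
    hook = TransformerOptionsHook(
        transformers_dict={"wrappers": {}},  # assumed placeholder contents
        hook_scope=EnumHookScope.HookedOnly,
    )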
@@ -254,8 +251,9 @@ class TransformerOptionsHook(Hook):
                                  "to_load_options": self.transformers_dict}
         else:
             add_model_options = {"to_load_options": self.transformers_dict}
-        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
+        # only register if will not be included in AllConditioning to avoid double loading
         registered.add(self)
+        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
         return True

     def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
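A rough illustration of the merge step above; the dict contents are assumed for the example:

    import comfy.patcher_extension

    model_options = {"transformer_options": {}}                 # assumed pre-existing options
    add_model_options = {"to_load_options": {"wrappers": {}}}   # key taken from the hunk above, contents assumed
    comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
    # with copy_dict1=False the entries are merged into model_options itself,
    # which is consistent with the return value not being used above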
@@ -265,10 +263,12 @@ WrapperHook = TransformerOptionsHook
 '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''

 class SetInjectionsHook(Hook):
-    def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
+    def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
         super().__init__(hook_type=EnumHookType.Injections)
         self.key = key
         self.injections = injections
+        self.hook_scope = hook_scope

     def clone(self):
         c: SetInjectionsHook = super().clone()
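Since WrapperHook is only an alias for TransformerOptionsHook, the constructor change above applies to it as well; a small sketch with an assumed argument value:

    from comfy.hooks import WrapperHook, EnumHookScope

    # identical to TransformerOptionsHook; keyword callers now use
    # transformers_dict rather than the old wrappers_dict name
    hook = WrapperHook({"wrappers": {}}, hook_scope=EnumHookScope.AllConditioning)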
@@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
         sorted_list.extend(object_list)
     return sorted_list

+def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
+    # if no hooks or is not a ModelPatcher for sampling, return empty dict
+    if hooks is None or model.is_clip:
+        return {}
+    if transformer_options is None:
+        transformer_options = {}
+    for hook in hooks.get_type(EnumHookType.TransformerOptions):
+        hook: TransformerOptionsHook
+        hook.on_apply_hooks(model, transformer_options)
+    return transformer_options
+
 def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
     hook_group = HookGroup()
     hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
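Call-site sketch for the helper added above; model_patcher and hook_group stand in for real objects and are assumptions:

    import comfy.hooks

    opts = comfy.hooks.create_transformer_options_from_hooks(model_patcher, hook_group)
    # returns {} when hook_group is None or when the patcher wraps a CLIP model
    # (model_patcher.is_clip); otherwise each TransformerOptionsHook in the group
    # writes its entries into the returned dict via on_apply_hooks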
comfy/model_patcher.py

@@ -1010,11 +1010,11 @@ class ModelPatcher:
     def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
         # TODO: return transformer_options dict with any additions from hooks
         if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
-            return {}
+            return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
         self.patch_hooks(hooks=hooks)
         for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
             callback(self, hooks)
-        return {}
+        return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)

     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
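Net effect for callers of ModelPatcher.apply_hooks, sketched with assumed names:

    transformer_options = model_patcher.apply_hooks(hooks=hook_group)
    # previously this always returned {}; it now returns the dict built by
    # comfy.hooks.create_transformer_options_from_hooks, so TransformerOptions
    # hooks in hook_group are reflected in the result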