diff --git a/comfy/hooks.py b/comfy/hooks.py index 7c2f6689..9d073107 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -42,7 +42,7 @@ class EnumHookType(enum.Enum): ''' Weight = "weight" ObjectPatch = "object_patch" - AddModels = "add_models" + AdditionalModels = "add_models" TransformerOptions = "transformer_options" Injections = "add_injections" @@ -202,24 +202,20 @@ class ObjectPatchHook(Hook): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True -class AddModelsHook(Hook): +class AdditionalModelsHook(Hook): ''' Hook responsible for telling model management any additional models that should be loaded. Note, value of hook_scope is ignored and is treated as AllConditioning. ''' def __init__(self, models: list[ModelPatcher]=None, key: str=None): - super().__init__(hook_type=EnumHookType.AddModels) + super().__init__(hook_type=EnumHookType.AdditionalModels) self.models = models self.key = key def clone(self): - c: AddModelsHook = super().clone() + c: AdditionalModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key return c @@ -271,7 +267,7 @@ class TransformerOptionsHook(Hook): WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' -class SetInjectionsHook(Hook): +class InjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None, hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) @@ -280,21 +276,13 @@ class SetInjectionsHook(Hook): self.hook_scope = hook_scope def clone(self): - c: SetInjectionsHook = super().clone() + c: InjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True - - def add_hook_injections(self, model: ModelPatcher): - # TODO: add functionality - pass + raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.") class HookGroup: ''' diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1433d185..b70e5e63 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -94,8 +94,8 @@ def get_additional_models_from_model_options(model_options: dict[str]=None): models = [] if model_options is not None and "registered_hooks" in model_options: registered: comfy.hooks.HookGroup = model_options["registered_hooks"] - for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook models.extend(hook.models) return models @@ -146,8 +146,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): hook: comfy.hooks.TransformerOptionsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all AddModelsHooks - for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all WeightHooks by registering on ModelPatcher model.register_all_hook_patches(hooks, target_dict, model_options, registered)