From adea2beb5c6d0dd1c96d4a7c3de0d391c2c94220 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 11 Jan 2025 02:18:42 -0500 Subject: [PATCH 1/6] Add edm option to ModelSamplingContinuousEDM for Cosmos. You can now use this node with "edm" selected to control the sigma_max and sigma_min of the Cosmos model sampling. --- comfy_extras/nodes_model_advanced.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 285dbf53..33f3face 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],), + "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), }} @@ -206,6 +206,9 @@ class ModelSamplingContinuousEDM: sigma_data = 1.0 if sampling == "eps": sampling_type = comfy.model_sampling.EPS + elif sampling == "edm": + sampling_type = comfy.model_sampling.EDM + sigma_data = 0.5 elif sampling == "v_prediction": sampling_type = comfy.model_sampling.V_PREDICTION elif sampling == "edm_playground_v2.5": From 9c773a241b446c7abb0cfabe1d292746251fed73 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Sat, 11 Jan 2025 03:09:25 -0500 Subject: [PATCH 2/6] Add pyproject.toml (#6386) * Add pyproject.toml * doc * Static version file * Add github action to sync version.py * Change trigger to PR * Fix commit * Grant pr write permission * nit * nit * Don't run on fork PRs * Rename version.py to comfyui_version.py --- .github/workflows/update-version.yml | 58 ++++++++++++++++++++++++++++ comfyui_version.py | 3 ++ pyproject.toml | 11 ++++++ server.py | 18 +-------- 4 files changed, 74 insertions(+), 16 deletions(-) create mode 100644 .github/workflows/update-version.yml create mode 100644 comfyui_version.py create mode 100644 pyproject.toml diff --git a/.github/workflows/update-version.yml b/.github/workflows/update-version.yml new file mode 100644 index 00000000..d9d48897 --- /dev/null +++ b/.github/workflows/update-version.yml @@ -0,0 +1,58 @@ +name: Update Version File + +on: + pull_request: + paths: + - "pyproject.toml" + branches: + - master + +jobs: + update-version: + runs-on: ubuntu-latest + # Don't run on fork PRs + if: github.event.pull_request.head.repo.full_name == github.repository + permissions: + pull-requests: write + contents: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + + - name: Update comfyui_version.py + run: | + # Read version from pyproject.toml and update comfyui_version.py + python -c ' + import tomllib + + # Read version from pyproject.toml + with open("pyproject.toml", "rb") as f: + config = tomllib.load(f) + version = config["project"]["version"] + + # Write version to comfyui_version.py + with open("comfyui_version.py", "w") as f: + f.write("# This file is automatically generated by the build process when version is\n") + f.write("# updated in pyproject.toml.\n") + f.write(f"__version__ = \"{version}\"\n") + ' + + - name: Commit changes + run: | + git config --local user.name "github-actions" + git config --local user.email "github-actions@github.com" + git fetch origin ${{ github.head_ref }} + git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }} + git add comfyui_version.py + git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml" + git push origin HEAD:${{ github.head_ref }} diff --git a/comfyui_version.py b/comfyui_version.py new file mode 100644 index 00000000..7cccc753 --- /dev/null +++ b/comfyui_version.py @@ -0,0 +1,3 @@ +# This file is automatically generated by the build process when version is +# updated in pyproject.toml. +__version__ = "0.3.10" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..1d9d7b3f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[project] +name = "ComfyUI" +version = "0.3.10" +readme = "README.md" +license = { file = "LICENSE" } +requires-python = ">=3.9" + +[project.urls] +homepage = "https://www.comfy.org/" +repository = "https://github.com/comfyanonymous/ComfyUI" +documentation = "https://docs.comfy.org/" diff --git a/server.py b/server.py index ceb5e83b..bae898ef 100644 --- a/server.py +++ b/server.py @@ -27,6 +27,7 @@ from comfy.cli_args import args import comfy.utils import comfy.model_management import node_helpers +from comfyui_version import __version__ from app.frontend_management import FrontendManager from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -44,21 +45,6 @@ async def send_socket_catch_exception(function, message): except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err: logging.warning("send error: {}".format(err)) -def get_comfyui_version(): - comfyui_version = "unknown" - repo_path = os.path.dirname(os.path.realpath(__file__)) - try: - import pygit2 - repo = pygit2.Repository(repo_path) - comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS) - except Exception: - try: - import subprocess - comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8') - except Exception as e: - logging.warning(f"Failed to get ComfyUI version: {e}") - return comfyui_version.strip() - @web.middleware async def cache_control(request: web.Request, handler): response: web.Response = await handler(request) @@ -518,7 +504,7 @@ class PromptServer(): "os": os.name, "ram_total": ram_total, "ram_free": ram_free, - "comfyui_version": get_comfyui_version(), + "comfyui_version": __version__, "python_version": sys.version, "pytorch_version": comfy.model_management.torch_version, "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", From ee8a7ab69d28b86285acd1b3779dd533e5b8cf6d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 11 Jan 2025 04:40:45 -0500 Subject: [PATCH 3/6] Fast latent preview for Cosmos. --- comfy/latent_formats.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 9e6dfc17..e98982c9 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -386,3 +386,24 @@ class HunyuanVideo(LatentFormat): class Cosmos1CV8x8x8(LatentFormat): latent_channels = 16 latent_dimensions = 3 + + latent_rgb_factors = [ + [ 0.1817, 0.2284, 0.2423], + [-0.0586, -0.0862, -0.3108], + [-0.4703, -0.4255, -0.3995], + [ 0.0803, 0.1963, 0.1001], + [-0.0820, -0.1050, 0.0400], + [ 0.2511, 0.3098, 0.2787], + [-0.1830, -0.2117, -0.0040], + [-0.0621, -0.2187, -0.0939], + [ 0.3619, 0.1082, 0.1455], + [ 0.3164, 0.3922, 0.2575], + [ 0.1152, 0.0231, -0.0462], + [-0.1434, -0.3609, -0.3665], + [ 0.0635, 0.1471, 0.1680], + [-0.3635, -0.1963, -0.3248], + [-0.1865, 0.0365, 0.2346], + [ 0.0447, 0.0994, 0.0881] + ] + + latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976] From 6c9bd11fa3445ed78381cea8ad48ffc95ea49d72 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 11 Jan 2025 11:20:23 -0600 Subject: [PATCH 4/6] Hooks Part 2 - TransformerOptionsHook and AdditionalModelsHook (#6377) * Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' * Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed * Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) * Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type * In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch * Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time * Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational * Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) * Filter only registered hooks on self.conds in CFGGuider.sample * Make hook_scope functional for TransformerOptionsHook * removed 4 whitespace lines to satisfy Ruff, * Add a get_injections function to ModelPatcher * Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable * Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) * Clean up a typehint --- comfy/hooks.py | 395 ++++++++++++++++++++++-------------- comfy/model_patcher.py | 34 ++-- comfy/sampler_helpers.py | 60 ++++-- comfy/samplers.py | 85 +++++++- comfy_extras/nodes_hooks.py | 4 +- 5 files changed, 385 insertions(+), 193 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 3cb0f396..9d073107 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -16,91 +16,132 @@ import comfy.model_management import comfy.patcher_extension from node_helpers import conditioning_set_values +# ####################################################################################################### +# Hooks explanation +# ------------------- +# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to +# make explicit special cases like it does for ControlNet and GLIGEN. +# +# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those +# that should run special code when a 'marked' cond is used in sampling. +# ####################################################################################################### + class EnumHookMode(enum.Enum): + ''' + Priority of hook memory optimization vs. speed, mostly related to WeightHooks. + + MinVram: No caching will occur for any operations related to hooks. + MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups. + ''' MinVram = "minvram" MaxSpeed = "maxspeed" class EnumHookType(enum.Enum): + ''' + Hook types, each of which has different expected behavior. + ''' Weight = "weight" - Patch = "patch" ObjectPatch = "object_patch" - AddModels = "add_models" - Callbacks = "callbacks" - Wrappers = "wrappers" - SetInjections = "add_injections" + AdditionalModels = "add_models" + TransformerOptions = "transformer_options" + Injections = "add_injections" class EnumWeightTarget(enum.Enum): Model = "model" Clip = "clip" +class EnumHookScope(enum.Enum): + ''' + Determines if hook should be limited in its influence over sampling. + + AllConditioning: hook will affect all conds used in sampling. + HookedOnly: hook will only affect the conds it was attached to. + ''' + AllConditioning = "all_conditioning" + HookedOnly = "hooked_only" + + class _HookRef: pass -# NOTE: this is an example of how the should_register function should look -def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + '''Example for how custom_should_register function can look like.''' return True +def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]: + '''Creates base dictionary for use with Hooks' target param.''' + d = {} + if target is not None: + d['target'] = target + d.update(kwargs) + return d + + class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, - hook_keyframe: 'HookKeyframeGroup'=None): + hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type + '''Enum identifying the general class of this hook.''' self.hook_ref = hook_ref if hook_ref else _HookRef() + '''Reference shared between hook clones that have the same value. Should NOT be modified.''' self.hook_id = hook_id + '''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.''' self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + '''Keyframe storage that can be referenced to get strength for current sampling step.''' + self.hook_scope = hook_scope + '''Scope of where this hook should apply in terms of the conds used in sampling run.''' self.custom_should_register = default_should_register - self.auto_apply_to_nonpositive = False + '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register''' @property def strength(self): return self.hook_keyframe.strength - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): self.reset() self.hook_keyframe.initialize_timesteps(model) def reset(self): self.hook_keyframe.reset() - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: Hook = subtype() + def clone(self): + c: Hook = self.__class__() c.hook_type = self.hook_type c.hook_ref = self.hook_ref c.hook_id = self.hook_id c.hook_keyframe = self.hook_keyframe + c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register - # TODO: make this do something - c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - return self.custom_should_register(self, model, model_options, target, registered) + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): - pass - - def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): - pass - - def __eq__(self, other: 'Hook'): + def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref def __hash__(self): return hash(self.hook_ref) class WeightHook(Hook): + ''' + Hook responsible for tracking weights to be applied to some model/clip. + + Note, value of hook_scope is ignored and is treated as HookedOnly. + ''' def __init__(self, strength_model=1.0, strength_clip=1.0): - super().__init__(hook_type=EnumHookType.Weight) + super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly) self.weights: dict = None self.weights_clip: dict = None self.need_weight_init = True self._strength_model = strength_model self._strength_clip = strength_clip + self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs @property def strength_model(self): @@ -110,36 +151,36 @@ class WeightHook(Hook): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + if not self.should_register(model, model_options, target_dict, registered): return False weights = None - if target == EnumWeightTarget.Model: - strength = self._strength_model - else: + + target = target_dict.get('target', None) + if target == EnumWeightTarget.Clip: strength = self._strength_clip + else: + strength = self._strength_model if self.need_weight_init: key_map = {} - if target == EnumWeightTarget.Model: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - else: + if target == EnumWeightTarget.Clip: key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) + else: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: - if target == EnumWeightTarget.Model: - weights = self.weights - else: + if target == EnumWeightTarget.Clip: weights = self.weights_clip + else: + weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) - registered.append(self) + registered.add(self) return True # TODO: add logs about any keys that were not applied - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WeightHook = super().clone(subtype) + def clone(self): + c: WeightHook = super().clone() c.weights = self.weights c.weights_clip = self.weights_clip c.need_weight_init = self.need_weight_init @@ -147,127 +188,158 @@ class WeightHook(Hook): c._strength_clip = self._strength_clip return c -class PatchHook(Hook): - def __init__(self): - super().__init__(hook_type=EnumHookType.Patch) - self.patches: dict = None - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: PatchHook = super().clone(subtype) - c.patches = self.patches - return c - # TODO: add functionality - class ObjectPatchHook(Hook): - def __init__(self): + def __init__(self, object_patches: dict[str]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.ObjectPatch) - self.object_patches: dict = None + self.object_patches = object_patches + self.hook_scope = hook_scope - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: ObjectPatchHook = super().clone(subtype) + def clone(self): + c: ObjectPatchHook = super().clone() c.object_patches = self.object_patches return c - # TODO: add functionality -class AddModelsHook(Hook): - def __init__(self, key: str=None, models: list['ModelPatcher']=None): - super().__init__(hook_type=EnumHookType.AddModels) - self.key = key + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") + +class AdditionalModelsHook(Hook): + ''' + Hook responsible for telling model management any additional models that should be loaded. + + Note, value of hook_scope is ignored and is treated as AllConditioning. + ''' + def __init__(self, models: list[ModelPatcher]=None, key: str=None): + super().__init__(hook_type=EnumHookType.AdditionalModels) self.models = models - self.append_when_same = True - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: AddModelsHook = super().clone(subtype) - c.key = self.key - c.models = self.models.copy() if self.models else self.models - c.append_when_same = self.append_when_same - return c - # TODO: add functionality - -class CallbackHook(Hook): - def __init__(self, key: str=None, callback: Callable=None): - super().__init__(hook_type=EnumHookType.Callbacks) self.key = key - self.callback = callback - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: CallbackHook = super().clone(subtype) + def clone(self): + c: AdditionalModelsHook = super().clone() + c.models = self.models.copy() if self.models else self.models c.key = self.key - c.callback = self.callback - return c - # TODO: add functionality - -class WrapperHook(Hook): - def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): - super().__init__(hook_type=EnumHookType.Wrappers) - self.wrappers_dict = wrappers_dict - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WrapperHook = super().clone(subtype) - c.wrappers_dict = self.wrappers_dict return c - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.wrappers_dict} - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.append(self) + registered.add(self) return True -class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list['PatcherInjection']=None): - super().__init__(hook_type=EnumHookType.SetInjections) +class TransformerOptionsHook(Hook): + ''' + Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. + ''' + def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None, + hook_scope=EnumHookScope.AllConditioning): + super().__init__(hook_type=EnumHookType.TransformerOptions) + self.transformers_dict = transformers_dict + self.hook_scope = hook_scope + self._skip_adding = False + '''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.''' + + def clone(self): + c: TransformerOptionsHook = super().clone() + c.transformers_dict = self.transformers_dict + c._skip_adding = self._skip_adding + return c + + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + if not self.should_register(model, model_options, target_dict, registered): + return False + # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks + self._skip_adding = False + if self.hook_scope == EnumHookScope.AllConditioning: + add_model_options = {"transformer_options": self.transformers_dict, + "to_load_options": self.transformers_dict} + # skip_adding if included in AllConditioning to avoid double loading + self._skip_adding = True + else: + add_model_options = {"to_load_options": self.transformers_dict} + registered.add(self) + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + return True + + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): + if not self._skip_adding: + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + +WrapperHook = TransformerOptionsHook +'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' + +class InjectionsHook(Hook): + def __init__(self, key: str=None, injections: list[PatcherInjection]=None, + hook_scope=EnumHookScope.AllConditioning): + super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections + self.hook_scope = hook_scope - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: SetInjectionsHook = super().clone(subtype) + def clone(self): + c: InjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c - def add_hook_injections(self, model: 'ModelPatcher'): - # TODO: add functionality - pass + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.") class HookGroup: + ''' + Stores groups of hooks, and allows them to be queried by type. + + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; + always use the provided functions on HookGroup. + ''' def __init__(self): self.hooks: list[Hook] = [] + self._hook_dict: dict[EnumHookType, list[Hook]] = {} + + def __len__(self): + return len(self.hooks) def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) + self._hook_dict.setdefault(hook.hook_type, []).append(hook) + + def remove(self, hook: Hook): + if hook in self.hooks: + self.hooks.remove(hook) + self._hook_dict[hook.hook_type].remove(hook) + + def get_type(self, hook_type: EnumHookType): + return self._hook_dict.get(hook_type, []) def contains(self, hook: Hook): return hook in self.hooks + def is_subset_of(self, other: HookGroup): + self_hooks = set(self.hooks) + other_hooks = set(other.hooks) + return self_hooks.issubset(other_hooks) + + def new_with_common_hooks(self, other: HookGroup): + c = HookGroup() + for hook in self.hooks: + if other.contains(hook): + c.add(hook.clone()) + return c + def clone(self): c = HookGroup() for hook in self.hooks: c.add(hook.clone()) return c - def clone_and_combine(self, other: 'HookGroup'): + def clone_and_combine(self, other: HookGroup): c = self.clone() if other is not None: for hook in other.hooks: c.add(hook.clone()) return c - def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): + def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup): if hook_kf is None: hook_kf = HookKeyframeGroup() else: @@ -275,36 +347,29 @@ class HookGroup: for hook in self.hooks: hook.hook_keyframe = hook_kf - def get_dict_repr(self): - d: dict[EnumHookType, dict[Hook, None]] = {} - for hook in self.hooks: - with_type = d.setdefault(hook.hook_type, {}) - with_type[hook] = None - return d - def get_hooks_for_clip_schedule(self): scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} - for hook in self.hooks: - # only care about WeightHooks, for now - if hook.hook_type == EnumHookType.Weight: - hook_schedule = [] - # if no hook keyframes, assign default value - if len(hook.hook_keyframe.keyframes) == 0: - hook_schedule.append(((0.0, 1.0), None)) - scheduled_hooks[hook] = hook_schedule - continue - # find ranges of values - prev_keyframe = hook.hook_keyframe.keyframes[0] - for keyframe in hook.hook_keyframe.keyframes: - if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): - hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) - prev_keyframe = keyframe - elif keyframe.start_percent == prev_keyframe.start_percent: - prev_keyframe = keyframe - # create final range, assuming last start_percent was not 1.0 - if not math.isclose(prev_keyframe.start_percent, 1.0): - hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + # only care about WeightHooks, for now + for hook in self.get_type(EnumHookType.Weight): + hook: WeightHook + hook_schedule = [] + # if no hook keyframes, assign default value + if len(hook.hook_keyframe.keyframes) == 0: + hook_schedule.append(((0.0, 1.0), None)) scheduled_hooks[hook] = hook_schedule + continue + # find ranges of values + prev_keyframe = hook.hook_keyframe.keyframes[0] + for keyframe in hook.hook_keyframe.keyframes: + if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): + hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) + prev_keyframe = keyframe + elif keyframe.start_percent == prev_keyframe.start_percent: + prev_keyframe = keyframe + # create final range, assuming last start_percent was not 1.0 + if not math.isclose(prev_keyframe.start_percent, 1.0): + hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + scheduled_hooks[hook] = hook_schedule # hooks should not have their schedules in a list of tuples all_ranges: list[tuple[float, float]] = [] for range_kfs in scheduled_hooks.values(): @@ -336,7 +401,7 @@ class HookGroup: hook.reset() @staticmethod - def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': + def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup: actual: list[HookGroup] = [] for group in hooks_list: if group is not None: @@ -433,7 +498,7 @@ class HookKeyframeGroup: c._set_first_as_current() return c - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) @@ -522,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list: sorted_list.extend(object_list) return sorted_list +def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None): + # if no hooks or is not a ModelPatcher for sampling, return empty dict + if hooks is None or model.is_clip: + return {} + if transformer_options is None: + transformer_options = {} + for hook in hooks.get_type(EnumHookType.TransformerOptions): + hook: TransformerOptionsHook + hook.on_apply_hooks(model, transformer_options) + return transformer_options + def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): hook_group = HookGroup() hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) @@ -548,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float hook.need_weight_init = False return hook_group -def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): +def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True): if model is None: return None patches_model: dict[str, torch.Tensor] = model.model.state_dict() @@ -560,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T return patches_model # NOTE: this function shows how to register weight hooks directly on the ModelPatchers -def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], +def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): key_map = {} if model is not None: @@ -612,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H else: c_dict[hooks_key] = cache[hooks_tuple] -def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): +def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True, + cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): c = [] - hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} + if cache is None: + cache = {} for t in conditioning: n = [t[0], t[1].copy()] for k in values: if append_hooks and k == 'hooks': - _combine_hooks_from_values(n[1], values, hooks_combine_cache) + _combine_hooks_from_values(n[1], values, cache) else: n[1][k] = values[k] c.append(n) return c -def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): +def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): if hooks is None: return cond - return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) + return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache) def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): if timestep_range is None: @@ -664,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list): def set_conds_props(conds: list, strength: float, set_cond_area: str, mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): final_conds = [] + cache = {} for c in conds: # first, apply lora_hook to conditioning, if provided - c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) + c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to conditioning c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) # apply timesteps, if present @@ -678,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str, def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, masked_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) + masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to new conditioning, if provided masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) # apply timesteps, if present @@ -692,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1. def set_default_conds_and_combine(conds: list, new_conds: list, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, new_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) + new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache) # next, add default_cond key to cond so that during sampling, it can be identified new_c = conditioning_set_values(new_c, {'default': True}) # apply timesteps, if present diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e886bdbb..0501f7b3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -210,7 +210,7 @@ class ModelPatcher: self.injections: dict[str, list[PatcherInjection]] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {} - self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} + self.hook_patches_backup: dict[comfy.hooks._HookRef] = None self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.current_hooks: Optional[comfy.hooks.HookGroup] = None @@ -282,7 +282,7 @@ class ModelPatcher: n.injections[k] = i.copy() # hooks n.hook_patches = create_hook_patches_clone(self.hook_patches) - n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) + n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup for group in self.cached_hook_patches: n.cached_hook_patches[group] = {} for k in self.cached_hook_patches[group]: @@ -855,6 +855,9 @@ class ModelPatcher: if key in self.injections: self.injections.pop(key) + def get_injections(self, key: str): + return self.injections.get(key, None) + def set_additional_models(self, key: str, models: list['ModelPatcher']): self.additional_models[key] = models @@ -925,9 +928,9 @@ class ModelPatcher: callback(self, timestep) def restore_hook_patches(self): - if len(self.hook_patches_backup) > 0: + if self.hook_patches_backup is not None: self.hook_patches = self.hook_patches_backup - self.hook_patches_backup = {} + self.hook_patches_backup = None def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode @@ -953,25 +956,26 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, + registered: comfy.hooks.HookGroup = None): self.restore_hook_patches() - registered_hooks: list[comfy.hooks.Hook] = [] - # handle WrapperHooks, if model_options provided - if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): - hook.add_hook_patches(self, model_options, target, registered_hooks) + if registered is None: + registered = comfy.hooks.HookGroup() # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) + else: + registered.add(hook) if len(weight_hooks_to_register) > 0: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target) + callback(self, hooks, target_dict, model_options, registered) + return registered def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): @@ -1022,11 +1026,11 @@ class ModelPatcher: def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): # TODO: return transformer_options dict with any additions from hooks if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) self.patch_hooks(hooks=hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): callback(self, hooks) - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) def patch_hooks(self, hooks: comfy.hooks.HookGroup): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index ac973536..b70e5e63 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type): models += [c[model_type]] return models -def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): +def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup): # get hooks from conds, and collect cnets so they can be checked for extra_hooks cnets: list[ControlBase] = [] for c in cond: if 'hooks' in c: for hook in c['hooks'].hooks: - hook: comfy.hooks.Hook - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) if 'control' in c: cnets.append(c['control']) @@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) if extra_hooks is not None: for hook in extra_hooks.hooks: - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) - return hooks_dict + return full_hooks def convert_cond(cond): out = [] @@ -73,13 +70,11 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") add_models += get_models_from_cond(conds[k], "additional_models") - get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -90,11 +85,20 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] - models = control_models + gligen + add_models + hook_models + models = control_models + gligen + add_models return models, inference_memory +def get_additional_models_from_model_options(model_options: dict[str]=None): + """loads additional models from registered AddModels hooks""" + models = [] + if model_options is not None and "registered_hooks" in model_options: + registered: comfy.hooks.HookGroup = model_options["registered_hooks"] + for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook + models.extend(hook.models) + return models + def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: @@ -102,9 +106,10 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): - real_model: 'BaseModel' = None +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): + real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory @@ -123,12 +128,35 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): + ''' + Registers hooks from conds. + ''' # check for hooks in conds - if not registered, see if can be applied - hooks = {} + hooks = comfy.hooks.HookGroup() for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) - # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) + # begin registering hooks + registered = comfy.hooks.HookGroup() + target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) + # handle all TransformerOptionsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): + hook: comfy.hooks.TransformerOptionsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all AddModelsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all WeightHooks by registering on ModelPatcher + model.register_all_hook_patches(hooks, target_dict, model_options, registered) + # add registered_hooks onto model_options for further reference + if len(registered) > 0: + model_options["registered_hooks"] = registered + # merge original wrappers and callbacks with hooked wrappers and callbacks + to_load_options: dict[str] = model_options.setdefault("to_load_options", {}) + for wc_name in ["wrappers", "callbacks"]: + comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], + copy_dict1=False) + return to_load_options diff --git a/comfy/samplers.py b/comfy/samplers.py index af2b8e11..5cc33a7d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]): for cond in conds_to_modify: cond['hooks'] = hooks +def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]): + '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for + HookGroups that have the same reference.''' + registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None) + # if None were registered, make sure all hooks are cleaned from conds + if registered is None: + for k in conds: + for kk in conds[k]: + kk.pop('hooks', None) + return + # find conds that contain hooks to be replaced - group by common HookGroup refs + hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {} + for k in conds: + for kk in conds[k]: + hooks: comfy.hooks.HookGroup = kk.get('hooks', None) + if hooks is not None: + if not hooks.is_subset_of(registered): + to_replace = hook_replacement.setdefault(hooks, []) + to_replace.append(kk) + # for each hook to replace, create a new proper HookGroup and assign to all common conds + for hooks, conds_to_modify in hook_replacement.items(): + new_hooks = hooks.new_with_common_hooks(registered) + if len(new_hooks) == 0: + new_hooks = None + for kk in conds_to_modify: + kk['hooks'] = new_hooks + def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): hooks_set = set() @@ -819,9 +846,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): return len(hooks_set) +def cast_to_load_options(model_options: dict[str], device=None, dtype=None): + ''' + If any patches from hooks, wrappers, or callbacks have .to to be called, call it. + ''' + if model_options is None: + return + to_load_options = model_options.get("to_load_options", None) + if to_load_options is None: + return + + casts = [] + if device is not None: + casts.append(device) + if dtype is not None: + casts.append(dtype) + # if nothing to apply, do nothing + if len(casts) == 0: + return + + # try to call .to on patches + if "patches" in to_load_options: + patches = to_load_options["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + for cast in casts: + patch_list[i] = patch_list[i].to(cast) + if "patches_replace" in to_load_options: + patches = to_load_options["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + for cast in casts: + patch_list[k] = patch_list[k].to(cast) + # try to call .to on any wrappers/callbacks + wrappers_and_callbacks = ["wrappers", "callbacks"] + for wc_name in wrappers_and_callbacks: + if wc_name in to_load_options: + wc: dict[str, list] = to_load_options[wc_name] + for wc_dict in wc.values(): + for wc_list in wc_dict.values(): + for i in range(len(wc_list)): + if hasattr(wc_list[i], "to"): + for cast in casts: + wc_list[i] = wc_list[i].to(cast) + + class CFGGuider: - def __init__(self, model_patcher): - self.model_patcher: 'ModelPatcher' = model_patcher + def __init__(self, model_patcher: ModelPatcher): + self.model_patcher = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -861,7 +937,7 @@ class CFGGuider: return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device if denoise_mask is not None: @@ -870,6 +946,7 @@ class CFGGuider: noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) try: self.model_patcher.pre_run() @@ -899,6 +976,7 @@ class CFGGuider: if get_total_hook_groups_in_conds(self.conds) <= 1: self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) + filter_registered_hooks_on_conds(self.conds, self.model_options) executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( self.outer_sample, self, @@ -906,6 +984,7 @@ class CFGGuider: ) output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: + cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 9d9d4837..1edc06f3 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -246,7 +246,7 @@ class SetClipHooks: CATEGORY = "advanced/hooks/clip" FUNCTION = "apply_hooks" - def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): + def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: clip = clip.clone() if apply_to_conds: @@ -255,7 +255,7 @@ class SetClipHooks: clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) + clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: From 42086af123554020278f2c68f06776295530198a Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Sat, 11 Jan 2025 12:52:46 -0500 Subject: [PATCH 5/6] Merge ruff.toml into pyproject.toml (#6431) --- pyproject.toml | 12 ++++++++++++ ruff.toml | 15 --------------- 2 files changed, 12 insertions(+), 15 deletions(-) delete mode 100644 ruff.toml diff --git a/pyproject.toml b/pyproject.toml index 1d9d7b3f..b747d6ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,3 +9,15 @@ requires-python = ">=3.9" homepage = "https://www.comfy.org/" repository = "https://github.com/comfyanonymous/ComfyUI" documentation = "https://docs.comfy.org/" + +[tool.ruff] +lint.select = [ + "S307", # suspicious-eval-usage + "S102", # exec + "T", # print-usage + "W", + # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. + # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f + "F", +] +exclude = ["*.ipynb"] diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 808a2ad0..00000000 --- a/ruff.toml +++ /dev/null @@ -1,15 +0,0 @@ -# Disable all rules by default -lint.ignore = ["ALL"] - -# Enable specific rules -lint.select = [ - "S307", # suspicious-eval-usage - "S102", # exec - "T", # print-usage - "W", - # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. - # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f - "F", -] - -exclude = ["*.ipynb"] From b9d9bcba1418711f13d7e432605f85303d900723 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 12 Jan 2025 03:19:51 +0300 Subject: [PATCH 6/6] fixed a bug where a relative path was not converted to a full path (#6395) Signed-off-by: bigcat88 --- tests-unit/utils/extra_config_test.py | 181 +++++++++++++++++++++++++- utils/extra_config.py | 6 +- 2 files changed, 183 insertions(+), 4 deletions(-) diff --git a/tests-unit/utils/extra_config_test.py b/tests-unit/utils/extra_config_test.py index 143e1f2e..b23f5bd0 100644 --- a/tests-unit/utils/extra_config_test.py +++ b/tests-unit/utils/extra_config_test.py @@ -1,11 +1,22 @@ import pytest import yaml import os +import sys from unittest.mock import Mock, patch, mock_open from utils.extra_config import load_extra_path_config import folder_paths + +@pytest.fixture() +def clear_folder_paths(): + # Clear the global dictionary before each test to ensure isolation + original = folder_paths.folder_names_and_paths.copy() + folder_paths.folder_names_and_paths.clear() + yield + folder_paths.folder_names_and_paths = original + + @pytest.fixture def mock_yaml_content(): return { @@ -15,10 +26,12 @@ def mock_yaml_content(): } } + @pytest.fixture def mock_expanded_home(): return '/home/user' + @pytest.fixture def yaml_config_with_appdata(): return """ @@ -27,20 +40,33 @@ def yaml_config_with_appdata(): checkpoints: 'models/checkpoints' """ + @pytest.fixture def mock_yaml_content_appdata(yaml_config_with_appdata): return yaml.safe_load(yaml_config_with_appdata) + @pytest.fixture def mock_expandvars_appdata(): mock = Mock() - mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming') + + def expandvars(path): + if '%APPDATA%' in path: + if sys.platform == 'win32': + return path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming') + else: + return path.replace('%APPDATA%', '/Users/TestUser/AppData/Roaming') + return path + + mock.side_effect = expandvars return mock + @pytest.fixture def mock_add_model_folder_path(): return Mock() + @pytest.fixture def mock_expanduser(mock_expanded_home): def _expanduser(path): @@ -49,10 +75,12 @@ def mock_expanduser(mock_expanded_home): return path return _expanduser + @pytest.fixture def mock_yaml_safe_load(mock_yaml_content): return Mock(return_value=mock_yaml_content) + @patch('builtins.open', new_callable=mock_open, read_data="dummy file content") def test_load_extra_model_paths_expands_userpath( mock_file, @@ -88,6 +116,7 @@ def test_load_extra_model_paths_expands_userpath( # Check if open was called with the correct file path mock_file.assert_called_once_with(dummy_yaml_file_name, 'r') + @patch('builtins.open', new_callable=mock_open) def test_load_extra_model_paths_expands_appdata( mock_file, @@ -111,7 +140,10 @@ def test_load_extra_model_paths_expands_appdata( dummy_yaml_file_name = 'dummy_path.yaml' load_extra_path_config(dummy_yaml_file_name) - expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI' + if sys.platform == "win32": + expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI' + else: + expected_base_path = '/Users/TestUser/AppData/Roaming/ComfyUI' expected_calls = [ ('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False), ] @@ -124,3 +156,148 @@ def test_load_extra_model_paths_expands_appdata( # Verify that expandvars was called assert mock_expandvars_appdata.called + + +@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content") +@patch("yaml.safe_load") +def test_load_extra_path_config_relative_base_path( + mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path +): + """ + Test that when 'base_path' is a relative path in the YAML, it is joined to the YAML file directory, and then + the items in the config are correctly converted to absolute paths. + """ + sub_folder = "./my_rel_base" + config_data = { + "some_model_folder": { + "base_path": sub_folder, + "is_default": True, + "checkpoints": "checkpoints", + "some_key": "some_value" + } + } + mock_yaml_load.return_value = config_data + + dummy_yaml_name = "dummy_file.yaml" + + def fake_abspath(path): + if path == dummy_yaml_name: + # If it's the YAML path, treat it like it lives in tmp_path + return os.path.join(str(tmp_path), dummy_yaml_name) + return os.path.join(str(tmp_path), path) # Otherwise, do a normal join relative to tmp_path + + def fake_dirname(path): + # We expect path to be the result of fake_abspath(dummy_yaml_name) + if path.endswith(dummy_yaml_name): + return str(tmp_path) + return os.path.dirname(path) + + monkeypatch.setattr(os.path, "abspath", fake_abspath) + monkeypatch.setattr(os.path, "dirname", fake_dirname) + + load_extra_path_config(dummy_yaml_name) + + expected_checkpoints = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "checkpoints")) + expected_some_value = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "some_value")) + + actual_paths = folder_paths.folder_names_and_paths["checkpoints"][0] + assert len(actual_paths) == 1, "Should have one path added for 'checkpoints'." + assert actual_paths[0] == expected_checkpoints + + actual_paths = folder_paths.folder_names_and_paths["some_key"][0] + assert len(actual_paths) == 1, "Should have one path added for 'some_key'." + assert actual_paths[0] == expected_some_value + + +@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content") +@patch("yaml.safe_load") +def test_load_extra_path_config_absolute_base_path( + mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path +): + """ + Test that when 'base_path' is an absolute path, each subdirectory is joined with that absolute path, + rather than being relative to the YAML's directory. + """ + abs_base = os.path.join(str(tmp_path), "abs_base") + config_data = { + "some_absolute_folder": { + "base_path": abs_base, # <-- absolute + "is_default": True, + "loras": "loras_folder", + "embeddings": "embeddings_folder" + } + } + mock_yaml_load.return_value = config_data + + dummy_yaml_name = "dummy_abs.yaml" + + def fake_abspath(path): + if path == dummy_yaml_name: + # If it's the YAML path, treat it like it is in tmp_path + return os.path.join(str(tmp_path), dummy_yaml_name) + return path # For absolute base, we just return path directly + + def fake_dirname(path): + return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path) + + monkeypatch.setattr(os.path, "abspath", fake_abspath) + monkeypatch.setattr(os.path, "dirname", fake_dirname) + + load_extra_path_config(dummy_yaml_name) + + # Expect the final paths to be /loras_folder and /embeddings_folder + expected_loras = os.path.join(abs_base, "loras_folder") + expected_embeddings = os.path.join(abs_base, "embeddings_folder") + + actual_loras = folder_paths.folder_names_and_paths["loras"][0] + assert len(actual_loras) == 1, "Should have one path for 'loras'." + assert actual_loras[0] == os.path.abspath(expected_loras) + + actual_embeddings = folder_paths.folder_names_and_paths["embeddings"][0] + assert len(actual_embeddings) == 1, "Should have one path for 'embeddings'." + assert actual_embeddings[0] == os.path.abspath(expected_embeddings) + + +@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content") +@patch("yaml.safe_load") +def test_load_extra_path_config_no_base_path( + mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path +): + """ + Test that if 'base_path' is not present, each path is joined + with the directory of the YAML file (unless it's already absolute). + """ + config_data = { + "some_folder_without_base": { + "is_default": True, + "text_encoders": "clip", + "diffusion_models": "unet" + } + } + mock_yaml_load.return_value = config_data + + dummy_yaml_name = "dummy_no_base.yaml" + + def fake_abspath(path): + if path == dummy_yaml_name: + return os.path.join(str(tmp_path), dummy_yaml_name) + return os.path.join(str(tmp_path), path) + + def fake_dirname(path): + return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path) + + monkeypatch.setattr(os.path, "abspath", fake_abspath) + monkeypatch.setattr(os.path, "dirname", fake_dirname) + + load_extra_path_config(dummy_yaml_name) + + expected_clip = os.path.join(str(tmp_path), "clip") + expected_unet = os.path.join(str(tmp_path), "unet") + + actual_text_encoders = folder_paths.folder_names_and_paths["text_encoders"][0] + assert len(actual_text_encoders) == 1, "Should have one path for 'text_encoders'." + assert actual_text_encoders[0] == os.path.abspath(expected_clip) + + actual_diffusion = folder_paths.folder_names_and_paths["diffusion_models"][0] + assert len(actual_diffusion) == 1, "Should have one path for 'diffusion_models'." + assert actual_diffusion[0] == os.path.abspath(expected_unet) diff --git a/utils/extra_config.py b/utils/extra_config.py index 415db042..d7b59285 100644 --- a/utils/extra_config.py +++ b/utils/extra_config.py @@ -6,6 +6,7 @@ import logging def load_extra_path_config(yaml_path): with open(yaml_path, 'r') as stream: config = yaml.safe_load(stream) + yaml_dir = os.path.dirname(os.path.abspath(yaml_path)) for c in config: conf = config[c] if conf is None: @@ -14,6 +15,8 @@ def load_extra_path_config(yaml_path): if "base_path" in conf: base_path = conf.pop("base_path") base_path = os.path.expandvars(os.path.expanduser(base_path)) + if not os.path.isabs(base_path): + base_path = os.path.abspath(os.path.join(yaml_dir, base_path)) is_default = False if "is_default" in conf: is_default = conf.pop("is_default") @@ -22,10 +25,9 @@ def load_extra_path_config(yaml_path): if len(y) == 0: continue full_path = y - if base_path is not None: + if base_path: full_path = os.path.join(base_path, full_path) elif not os.path.isabs(full_path): - yaml_dir = os.path.dirname(os.path.abspath(yaml_path)) full_path = os.path.abspath(os.path.join(yaml_dir, y)) logging.info("Adding extra search path {} {}".format(x, full_path)) folder_paths.add_model_folder_path(x, full_path, is_default)