mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
ModelPatcher Overhaul and Hook Support (#5583)
* Added hook_patches to ModelPatcher for weights (model) * Initial changes to calc_cond_batch to eventually support hook_patches * Added current_patcher property to BaseModel * Consolidated add_hook_patches_as_diffs into add_hook_patches func, fixed fp8 support for model-as-lora feature * Added call to initialize_timesteps on hooks in process_conds func, and added call prepare current keyframe on hooks in calc_cond_batch * Added default_conds support in calc_cond_batch func * Added initial set of hook-related nodes, added code to register hooks for loras/model-as-loras, small renaming/refactoring * Made CLIP work with hook patches * Added initial hook scheduling nodes, small renaming/refactoring * Fixed MaxSpeed and default conds implementations * Added support for adding weight hooks that aren't registered on the ModelPatcher at sampling time * Made Set Clip Hooks node work with hooks from Create Hook nodes, began work on better Create Hook Model As LoRA node * Initial work on adding 'model_as_lora' lora type to calculate_weight * Continued work on simpler Create Hook Model As LoRA node, started to implement ModelPatcher callbacks, attachments, and additional_models * Fix incorrect ref to create_hook_patches_clone after moving function * Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks * Added wrappers to ModelPatcher to facilitate standardized function wrapping * Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type * Fix skip_until_exit logic bug breaking injection after first run of model * Updated clone_has_same_weights function to account for new ModelPatcher properties, improved AutoPatcherEjector usage in partially_load * Added WrapperExecutor for non-classbound functions, added calc_cond_batch wrappers * Refactored callbacks+wrappers to allow storing lists by id * Added forward_timestep_embed_patch type, added helper functions on ModelPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks * Added get_attachment func on ModelPatcher * Implement basic MemoryCounter system for determing with cached weights due to hooks should be offloaded in hooks_backup * Modified ControlNet/T2IAdapter get_control function to receive transformer_options as additional parameter, made the model_options stored in extra_args in inner_sample be a clone of the original model_options instead of same ref * Added create_model_options_clone func, modified type annotations to use __future__ so that I can use the better type annotations * Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review, will likely keep but will see what hacks this could currently let me get rid of in ACN/ADE) * Added Combine versions of Cond/Cond Pair Set Props nodes, renamed Pair Cond to Cond Pair, fixed default conds never applying hooks (due to hooks key typo) * Renamed Create Hook Model As LoRA nodes to make the test node the main one (more changes pending) * Added uuid to conds in CFGGuider and uuids to transformer_options to allow uniquely identifying conds in batches during sampling * Fixed models not being unloaded properly due to current_patcher reference; the current ComfyUI model cleanup code requires that nothing else has a reference to the ModelPatcher instances * Fixed default conds not respecting hook keyframes, made keyframes not reset cache when strength is unchanged, fixed Cond Set Default Combine throwing error, fixed model-as-lora throwing error during calculate_weight after a recent ComfyUI update, small refactoring/scaffolding changes for hooks * Changed CreateHookModelAsLoraTest to be the new CreateHookModelAsLora, rename old ones as 'direct' and will be removed prior to merge * Added initial support within CLIP Text Encode (Prompt) node for scheduling weight hook CLIP strength via clip_start_percent/clip_end_percent on conds, added schedule_clip toggle to Set CLIP Hooks node, small cleanup/fixes * Fix range check in get_hooks_for_clip_schedule so that proper keyframes get assigned to corresponding ranges * Optimized CLIP hook scheduling to treat same strength as same keyframe * Less fragile memory management. * Make encode_from_tokens_scheduled call cleaner, rollback change in model_patcher.py for hook_patches_backup dict * Fix issue. * Remove useless function. * Prevent and detect some types of memory leaks. * Run garbage collector when switching workflow if needed. * Moved WrappersMP/CallbacksMP/WrapperExecutor to patcher_extension.py * Refactored code to store wrappers and callbacks in transformer_options, added apply_model and diffusion_model.forward wrappers * Fix issue. * Refactored hooks in calc_cond_batch to be part of get_area_and_mult tuple, added extra_hooks to ControlBase to allow custom controlnets w/ hooks, small cleanup and renaming * Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options * Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired * Updated different hook types to reflect actual progress of implementation, initial scaffolding for working WrapperHook functionality * Fixed existing weight hook_patches (pre-registered) not working properly for CLIP * Removed Register/Direct hook nodes since they were present only for testing, removed diff-related weight hook calculation as improved_memory removes unload_model_clones and using sample time registered hooks is less hacky * Added clip scheduling support to all other native ComfyUI text encoding nodes (sdxl, flux, hunyuan, sd3) * Made WrapperHook functional, added another wrapper/callback getter, added ON_DETACH callback to ModelPatcher * Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate * Added apply_to_conds to Set CLIP Hooks, modified relevant code to allow text encoding to automatically apply hooks to output conds when apply_to_conds is set to True * Fix cached_hook_patches not respecting target_device/memory_counter results * Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches * Remove unnecessary torch.no_grad calls for hook patches * Increased MemoryCounter minimum memory to leave free by *2 until a better way to get inference memory estimate of currently loaded models exists * For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes * Removed a .to call on results of calculate_weight in patch_hook_weight_to_device that was screwing up the intermediate results for fp8 prior to being passed into stochastic_rounding call * Made encode_from_tokens_scheduled work when no hooks are set on patcher * Small cleanup of comments * Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case * On Cond/Cond Pair nodes, removed opt_ prefix from optional inputs * Allow both FLOATS and FLOAT for floats_strength input * Revert change, does not work * Made patch_hook_weight_to_device respect set_func and convert_func * Make discard_model_sampling True by default * Add changes manually from 'master' so merge conflict resolution goes more smoothly * Cleaned up text encode nodes with just a single clip.encode_from_tokens_scheduled call * Make sure encode_from_tokens_scheduled will respect use_clip_schedule on clip * Made nodes in nodes_hooks be marked as experimental (beta) * Add get_nested_additional_models for cases where additional_models could have their own additional_models, and add robustness for circular additional_models references * Made finalize_default_conds area math consistent with other sampling code * Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks' * Remove a couple old TODO's and a no longer necessary workaround
This commit is contained in:
parent
79d5ceae6e
commit
0ee322ec5f
@ -36,6 +36,10 @@ import comfy.cldm.mmdit
|
|||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet
|
import comfy.ldm.flux.controlnet
|
||||||
import comfy.cldm.dit_embedder
|
import comfy.cldm.dit_embedder
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.hooks import HookGroup
|
||||||
|
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
current_batch_size = tensor.shape[0]
|
current_batch_size = tensor.shape[0]
|
||||||
@ -78,6 +82,7 @@ class ControlBase:
|
|||||||
self.concat_mask = False
|
self.concat_mask = False
|
||||||
self.extra_concat_orig = []
|
self.extra_concat_orig = []
|
||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
|
self.extra_hooks: HookGroup = None
|
||||||
self.preprocess_image = lambda a: a
|
self.preprocess_image = lambda a: a
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
@ -116,6 +121,14 @@ class ControlBase:
|
|||||||
out += self.previous_controlnet.get_models()
|
out += self.previous_controlnet.get_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def get_extra_hooks(self):
|
||||||
|
out = []
|
||||||
|
if self.extra_hooks is not None:
|
||||||
|
out.append(self.extra_hooks)
|
||||||
|
if self.previous_controlnet is not None:
|
||||||
|
out += self.previous_controlnet.get_extra_hooks()
|
||||||
|
return out
|
||||||
|
|
||||||
def copy_to(self, c):
|
def copy_to(self, c):
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
@ -130,6 +143,7 @@ class ControlBase:
|
|||||||
c.strength_type = self.strength_type
|
c.strength_type = self.strength_type
|
||||||
c.concat_mask = self.concat_mask
|
c.concat_mask = self.concat_mask
|
||||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||||
|
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
|
||||||
c.preprocess_image = self.preprocess_image
|
c.preprocess_image = self.preprocess_image
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
@ -200,10 +214,10 @@ class ControlNet(ControlBase):
|
|||||||
self.concat_mask = concat_mask
|
self.concat_mask = concat_mask
|
||||||
self.preprocess_image = preprocess_image
|
self.preprocess_image = preprocess_image
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||||
|
|
||||||
if self.timestep_range is not None:
|
if self.timestep_range is not None:
|
||||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||||
@ -758,10 +772,10 @@ class T2IAdapter(ControlBase):
|
|||||||
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
||||||
return width, height
|
return width, height
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||||
|
|
||||||
if self.timestep_range is not None:
|
if self.timestep_range is not None:
|
||||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||||
|
690
comfy/hooks.py
Normal file
690
comfy/hooks.py
Normal file
@ -0,0 +1,690 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
import enum
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
||||||
|
from comfy.model_base import BaseModel
|
||||||
|
from comfy.sd import CLIP
|
||||||
|
import comfy.lora
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.patcher_extension
|
||||||
|
from node_helpers import conditioning_set_values
|
||||||
|
|
||||||
|
class EnumHookMode(enum.Enum):
|
||||||
|
MinVram = "minvram"
|
||||||
|
MaxSpeed = "maxspeed"
|
||||||
|
|
||||||
|
class EnumHookType(enum.Enum):
|
||||||
|
Weight = "weight"
|
||||||
|
Patch = "patch"
|
||||||
|
ObjectPatch = "object_patch"
|
||||||
|
AddModels = "add_models"
|
||||||
|
Callbacks = "callbacks"
|
||||||
|
Wrappers = "wrappers"
|
||||||
|
SetInjections = "add_injections"
|
||||||
|
|
||||||
|
class EnumWeightTarget(enum.Enum):
|
||||||
|
Model = "model"
|
||||||
|
Clip = "clip"
|
||||||
|
|
||||||
|
class _HookRef:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# NOTE: this is an example of how the should_register function should look
|
||||||
|
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class Hook:
|
||||||
|
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||||
|
hook_keyframe: 'HookKeyframeGroup'=None):
|
||||||
|
self.hook_type = hook_type
|
||||||
|
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||||
|
self.hook_id = hook_id
|
||||||
|
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||||
|
self.custom_should_register = default_should_register
|
||||||
|
self.auto_apply_to_nonpositive = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def strength(self):
|
||||||
|
return self.hook_keyframe.strength
|
||||||
|
|
||||||
|
def initialize_timesteps(self, model: 'BaseModel'):
|
||||||
|
self.reset()
|
||||||
|
self.hook_keyframe.initialize_timesteps(model)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.hook_keyframe.reset()
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: Hook = subtype()
|
||||||
|
c.hook_type = self.hook_type
|
||||||
|
c.hook_ref = self.hook_ref
|
||||||
|
c.hook_id = self.hook_id
|
||||||
|
c.hook_keyframe = self.hook_keyframe
|
||||||
|
c.custom_should_register = self.custom_should_register
|
||||||
|
# TODO: make this do something
|
||||||
|
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||||
|
return c
|
||||||
|
|
||||||
|
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||||
|
return self.custom_should_register(self, model, model_options, target, registered)
|
||||||
|
|
||||||
|
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||||
|
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||||
|
|
||||||
|
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __eq__(self, other: 'Hook'):
|
||||||
|
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.hook_ref)
|
||||||
|
|
||||||
|
class WeightHook(Hook):
|
||||||
|
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||||
|
super().__init__(hook_type=EnumHookType.Weight)
|
||||||
|
self.weights: dict = None
|
||||||
|
self.weights_clip: dict = None
|
||||||
|
self.need_weight_init = True
|
||||||
|
self._strength_model = strength_model
|
||||||
|
self._strength_clip = strength_clip
|
||||||
|
|
||||||
|
@property
|
||||||
|
def strength_model(self):
|
||||||
|
return self._strength_model * self.strength
|
||||||
|
|
||||||
|
@property
|
||||||
|
def strength_clip(self):
|
||||||
|
return self._strength_clip * self.strength
|
||||||
|
|
||||||
|
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||||
|
if not self.should_register(model, model_options, target, registered):
|
||||||
|
return False
|
||||||
|
weights = None
|
||||||
|
if target == EnumWeightTarget.Model:
|
||||||
|
strength = self._strength_model
|
||||||
|
else:
|
||||||
|
strength = self._strength_clip
|
||||||
|
|
||||||
|
if self.need_weight_init:
|
||||||
|
key_map = {}
|
||||||
|
if target == EnumWeightTarget.Model:
|
||||||
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
|
else:
|
||||||
|
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||||
|
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||||
|
else:
|
||||||
|
if target == EnumWeightTarget.Model:
|
||||||
|
weights = self.weights
|
||||||
|
else:
|
||||||
|
weights = self.weights_clip
|
||||||
|
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||||
|
registered.append(self)
|
||||||
|
return True
|
||||||
|
# TODO: add logs about any keys that were not applied
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: WeightHook = super().clone(subtype)
|
||||||
|
c.weights = self.weights
|
||||||
|
c.weights_clip = self.weights_clip
|
||||||
|
c.need_weight_init = self.need_weight_init
|
||||||
|
c._strength_model = self._strength_model
|
||||||
|
c._strength_clip = self._strength_clip
|
||||||
|
return c
|
||||||
|
|
||||||
|
class PatchHook(Hook):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(hook_type=EnumHookType.Patch)
|
||||||
|
self.patches: dict = None
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: PatchHook = super().clone(subtype)
|
||||||
|
c.patches = self.patches
|
||||||
|
return c
|
||||||
|
# TODO: add functionality
|
||||||
|
|
||||||
|
class ObjectPatchHook(Hook):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||||
|
self.object_patches: dict = None
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: ObjectPatchHook = super().clone(subtype)
|
||||||
|
c.object_patches = self.object_patches
|
||||||
|
return c
|
||||||
|
# TODO: add functionality
|
||||||
|
|
||||||
|
class AddModelsHook(Hook):
|
||||||
|
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
||||||
|
super().__init__(hook_type=EnumHookType.AddModels)
|
||||||
|
self.key = key
|
||||||
|
self.models = models
|
||||||
|
self.append_when_same = True
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: AddModelsHook = super().clone(subtype)
|
||||||
|
c.key = self.key
|
||||||
|
c.models = self.models.copy() if self.models else self.models
|
||||||
|
c.append_when_same = self.append_when_same
|
||||||
|
return c
|
||||||
|
# TODO: add functionality
|
||||||
|
|
||||||
|
class CallbackHook(Hook):
|
||||||
|
def __init__(self, key: str=None, callback: Callable=None):
|
||||||
|
super().__init__(hook_type=EnumHookType.Callbacks)
|
||||||
|
self.key = key
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: CallbackHook = super().clone(subtype)
|
||||||
|
c.key = self.key
|
||||||
|
c.callback = self.callback
|
||||||
|
return c
|
||||||
|
# TODO: add functionality
|
||||||
|
|
||||||
|
class WrapperHook(Hook):
|
||||||
|
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||||
|
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||||
|
self.wrappers_dict = wrappers_dict
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: WrapperHook = super().clone(subtype)
|
||||||
|
c.wrappers_dict = self.wrappers_dict
|
||||||
|
return c
|
||||||
|
|
||||||
|
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||||
|
if not self.should_register(model, model_options, target, registered):
|
||||||
|
return False
|
||||||
|
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||||
|
registered.append(self)
|
||||||
|
return True
|
||||||
|
|
||||||
|
class SetInjectionsHook(Hook):
|
||||||
|
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
||||||
|
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||||
|
self.key = key
|
||||||
|
self.injections = injections
|
||||||
|
|
||||||
|
def clone(self, subtype: Callable=None):
|
||||||
|
if subtype is None:
|
||||||
|
subtype = type(self)
|
||||||
|
c: SetInjectionsHook = super().clone(subtype)
|
||||||
|
c.key = self.key
|
||||||
|
c.injections = self.injections.copy() if self.injections else self.injections
|
||||||
|
return c
|
||||||
|
|
||||||
|
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||||
|
# TODO: add functionality
|
||||||
|
pass
|
||||||
|
|
||||||
|
class HookGroup:
|
||||||
|
def __init__(self):
|
||||||
|
self.hooks: list[Hook] = []
|
||||||
|
|
||||||
|
def add(self, hook: Hook):
|
||||||
|
if hook not in self.hooks:
|
||||||
|
self.hooks.append(hook)
|
||||||
|
|
||||||
|
def contains(self, hook: Hook):
|
||||||
|
return hook in self.hooks
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = HookGroup()
|
||||||
|
for hook in self.hooks:
|
||||||
|
c.add(hook.clone())
|
||||||
|
return c
|
||||||
|
|
||||||
|
def clone_and_combine(self, other: 'HookGroup'):
|
||||||
|
c = self.clone()
|
||||||
|
if other is not None:
|
||||||
|
for hook in other.hooks:
|
||||||
|
c.add(hook.clone())
|
||||||
|
return c
|
||||||
|
|
||||||
|
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||||
|
if hook_kf is None:
|
||||||
|
hook_kf = HookKeyframeGroup()
|
||||||
|
else:
|
||||||
|
hook_kf = hook_kf.clone()
|
||||||
|
for hook in self.hooks:
|
||||||
|
hook.hook_keyframe = hook_kf
|
||||||
|
|
||||||
|
def get_dict_repr(self):
|
||||||
|
d: dict[EnumHookType, dict[Hook, None]] = {}
|
||||||
|
for hook in self.hooks:
|
||||||
|
with_type = d.setdefault(hook.hook_type, {})
|
||||||
|
with_type[hook] = None
|
||||||
|
return d
|
||||||
|
|
||||||
|
def get_hooks_for_clip_schedule(self):
|
||||||
|
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||||
|
for hook in self.hooks:
|
||||||
|
# only care about WeightHooks, for now
|
||||||
|
if hook.hook_type == EnumHookType.Weight:
|
||||||
|
hook_schedule = []
|
||||||
|
# if no hook keyframes, assign default value
|
||||||
|
if len(hook.hook_keyframe.keyframes) == 0:
|
||||||
|
hook_schedule.append(((0.0, 1.0), None))
|
||||||
|
scheduled_hooks[hook] = hook_schedule
|
||||||
|
continue
|
||||||
|
# find ranges of values
|
||||||
|
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||||
|
for keyframe in hook.hook_keyframe.keyframes:
|
||||||
|
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
# create final range, assuming last start_percent was not 1.0
|
||||||
|
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||||
|
scheduled_hooks[hook] = hook_schedule
|
||||||
|
# hooks should not have their schedules in a list of tuples
|
||||||
|
all_ranges: list[tuple[float, float]] = []
|
||||||
|
for range_kfs in scheduled_hooks.values():
|
||||||
|
for t_range, keyframe in range_kfs:
|
||||||
|
all_ranges.append(t_range)
|
||||||
|
# turn list of ranges into boundaries
|
||||||
|
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
|
||||||
|
boundaries_set.add(0.0)
|
||||||
|
boundaries = sorted(boundaries_set)
|
||||||
|
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
|
||||||
|
# with real ranges defined, give appropriate hooks w/ keyframes for each range
|
||||||
|
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
|
||||||
|
for t_range in real_ranges:
|
||||||
|
hooks_schedule = []
|
||||||
|
for hook, val in scheduled_hooks.items():
|
||||||
|
keyframe = None
|
||||||
|
# check if is a keyframe that works for the current t_range
|
||||||
|
for stored_range, stored_kf in val:
|
||||||
|
# if stored start is less than current end, then fits - give it assigned keyframe
|
||||||
|
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
|
||||||
|
keyframe = stored_kf
|
||||||
|
break
|
||||||
|
hooks_schedule.append((hook, keyframe))
|
||||||
|
scheduled_keyframes.append((t_range, hooks_schedule))
|
||||||
|
return scheduled_keyframes
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
for hook in self.hooks:
|
||||||
|
hook.reset()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
||||||
|
actual: list[HookGroup] = []
|
||||||
|
for group in hooks_list:
|
||||||
|
if group is not None:
|
||||||
|
actual.append(group)
|
||||||
|
if len(actual) < require_count:
|
||||||
|
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
|
||||||
|
# if no hooks, then return None
|
||||||
|
if len(actual) == 0:
|
||||||
|
return None
|
||||||
|
# if only 1 hook, just return itself without cloning
|
||||||
|
elif len(actual) == 1:
|
||||||
|
return actual[0]
|
||||||
|
final_hook: HookGroup = None
|
||||||
|
for hook in actual:
|
||||||
|
if final_hook is None:
|
||||||
|
final_hook = hook.clone()
|
||||||
|
else:
|
||||||
|
final_hook = final_hook.clone_and_combine(hook)
|
||||||
|
return final_hook
|
||||||
|
|
||||||
|
|
||||||
|
class HookKeyframe:
|
||||||
|
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
|
||||||
|
self.strength = strength
|
||||||
|
# scheduling
|
||||||
|
self.start_percent = float(start_percent)
|
||||||
|
self.start_t = 999999999.9
|
||||||
|
self.guarantee_steps = guarantee_steps
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = HookKeyframe(strength=self.strength,
|
||||||
|
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||||
|
c.start_t = self.start_t
|
||||||
|
return c
|
||||||
|
|
||||||
|
class HookKeyframeGroup:
|
||||||
|
def __init__(self):
|
||||||
|
self.keyframes: list[HookKeyframe] = []
|
||||||
|
self._current_keyframe: HookKeyframe = None
|
||||||
|
self._current_used_steps = 0
|
||||||
|
self._current_index = 0
|
||||||
|
self._current_strength = None
|
||||||
|
self._curr_t = -1.
|
||||||
|
|
||||||
|
# properties shadow those of HookWeightsKeyframe
|
||||||
|
@property
|
||||||
|
def strength(self):
|
||||||
|
if self._current_keyframe is not None:
|
||||||
|
return self._current_keyframe.strength
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._current_keyframe = None
|
||||||
|
self._current_used_steps = 0
|
||||||
|
self._current_index = 0
|
||||||
|
self._current_strength = None
|
||||||
|
self.curr_t = -1.
|
||||||
|
self._set_first_as_current()
|
||||||
|
|
||||||
|
def add(self, keyframe: HookKeyframe):
|
||||||
|
# add to end of list, then sort
|
||||||
|
self.keyframes.append(keyframe)
|
||||||
|
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
|
||||||
|
self._set_first_as_current()
|
||||||
|
|
||||||
|
def _set_first_as_current(self):
|
||||||
|
if len(self.keyframes) > 0:
|
||||||
|
self._current_keyframe = self.keyframes[0]
|
||||||
|
else:
|
||||||
|
self._current_keyframe = None
|
||||||
|
|
||||||
|
def has_index(self, index: int):
|
||||||
|
return index >= 0 and index < len(self.keyframes)
|
||||||
|
|
||||||
|
def is_empty(self):
|
||||||
|
return len(self.keyframes) == 0
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = HookKeyframeGroup()
|
||||||
|
for keyframe in self.keyframes:
|
||||||
|
c.keyframes.append(keyframe.clone())
|
||||||
|
c._set_first_as_current()
|
||||||
|
return c
|
||||||
|
|
||||||
|
def initialize_timesteps(self, model: 'BaseModel'):
|
||||||
|
for keyframe in self.keyframes:
|
||||||
|
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||||
|
|
||||||
|
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
||||||
|
if self.is_empty():
|
||||||
|
return False
|
||||||
|
if curr_t == self._curr_t:
|
||||||
|
return False
|
||||||
|
prev_index = self._current_index
|
||||||
|
prev_strength = self._current_strength
|
||||||
|
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||||
|
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
||||||
|
# if has next index, loop through and see if need to switch
|
||||||
|
if self.has_index(self._current_index+1):
|
||||||
|
for i in range(self._current_index+1, len(self.keyframes)):
|
||||||
|
eval_c = self.keyframes[i]
|
||||||
|
# check if start_t is greater or equal to curr_t
|
||||||
|
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
|
||||||
|
if eval_c.start_t >= curr_t:
|
||||||
|
self._current_index = i
|
||||||
|
self._current_strength = eval_c.strength
|
||||||
|
self._current_keyframe = eval_c
|
||||||
|
self._current_used_steps = 0
|
||||||
|
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||||
|
if self._current_keyframe.guarantee_steps > 0:
|
||||||
|
break
|
||||||
|
# if eval_c is outside the percent range, stop looking further
|
||||||
|
else: break
|
||||||
|
# update steps current context is used
|
||||||
|
self._current_used_steps += 1
|
||||||
|
# update current timestep this was performed on
|
||||||
|
self._curr_t = curr_t
|
||||||
|
# return True if keyframe changed, False if no change
|
||||||
|
return prev_index != self._current_index and prev_strength != self._current_strength
|
||||||
|
|
||||||
|
|
||||||
|
class InterpolationMethod:
|
||||||
|
LINEAR = "linear"
|
||||||
|
EASE_IN = "ease_in"
|
||||||
|
EASE_OUT = "ease_out"
|
||||||
|
EASE_IN_OUT = "ease_in_out"
|
||||||
|
|
||||||
|
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
|
||||||
|
diff = num_to - num_from
|
||||||
|
if method == cls.LINEAR:
|
||||||
|
weights = torch.linspace(num_from, num_to, length)
|
||||||
|
elif method == cls.EASE_IN:
|
||||||
|
index = torch.linspace(0, 1, length)
|
||||||
|
weights = diff * np.power(index, 2) + num_from
|
||||||
|
elif method == cls.EASE_OUT:
|
||||||
|
index = torch.linspace(0, 1, length)
|
||||||
|
weights = diff * (1 - np.power(1 - index, 2)) + num_from
|
||||||
|
elif method == cls.EASE_IN_OUT:
|
||||||
|
index = torch.linspace(0, 1, length)
|
||||||
|
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unrecognized interpolation method '{method}'.")
|
||||||
|
if reverse:
|
||||||
|
weights = weights.flip(dims=(0,))
|
||||||
|
return weights
|
||||||
|
|
||||||
|
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||||
|
if not objects:
|
||||||
|
return objects
|
||||||
|
elif len(objects) <= 1:
|
||||||
|
return [x for x in objects]
|
||||||
|
# now that we know we have to sort, do it following these rules:
|
||||||
|
# a) if objects have same value of attribute, maintain their relative order
|
||||||
|
# b) perform sorting of the groups of objects with same attributes
|
||||||
|
unique_attrs = {}
|
||||||
|
for o in objects:
|
||||||
|
val_attr = getattr(o, attr)
|
||||||
|
attr_list: list = unique_attrs.get(val_attr, list())
|
||||||
|
attr_list.append(o)
|
||||||
|
if val_attr not in unique_attrs:
|
||||||
|
unique_attrs[val_attr] = attr_list
|
||||||
|
# now that we have the unique attr values grouped together in relative order, sort them by key
|
||||||
|
sorted_attrs = dict(sorted(unique_attrs.items()))
|
||||||
|
# now flatten out the dict into a list to return
|
||||||
|
sorted_list = []
|
||||||
|
for object_list in sorted_attrs.values():
|
||||||
|
sorted_list.extend(object_list)
|
||||||
|
return sorted_list
|
||||||
|
|
||||||
|
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||||
|
hook_group = HookGroup()
|
||||||
|
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||||
|
hook_group.add(hook)
|
||||||
|
hook.weights = lora
|
||||||
|
return hook_group
|
||||||
|
|
||||||
|
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
|
||||||
|
hook_group = HookGroup()
|
||||||
|
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||||
|
hook_group.add(hook)
|
||||||
|
patches_model = None
|
||||||
|
patches_clip = None
|
||||||
|
if weights_model is not None:
|
||||||
|
patches_model = {}
|
||||||
|
for key in weights_model:
|
||||||
|
patches_model[key] = ("model_as_lora", (weights_model[key],))
|
||||||
|
if weights_clip is not None:
|
||||||
|
patches_clip = {}
|
||||||
|
for key in weights_clip:
|
||||||
|
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
|
||||||
|
hook.weights = patches_model
|
||||||
|
hook.weights_clip = patches_clip
|
||||||
|
hook.need_weight_init = False
|
||||||
|
return hook_group
|
||||||
|
|
||||||
|
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||||
|
if discard_model_sampling:
|
||||||
|
# do not include ANY model_sampling components of the model that should act as a patch
|
||||||
|
for key in list(patches_model.keys()):
|
||||||
|
if key.startswith("model_sampling"):
|
||||||
|
patches_model.pop(key, None)
|
||||||
|
return patches_model
|
||||||
|
|
||||||
|
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||||
|
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
||||||
|
strength_model: float, strength_clip: float):
|
||||||
|
key_map = {}
|
||||||
|
if model is not None:
|
||||||
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
|
if clip is not None:
|
||||||
|
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||||
|
|
||||||
|
hook_group = HookGroup()
|
||||||
|
hook = WeightHook()
|
||||||
|
hook_group.add(hook)
|
||||||
|
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
|
||||||
|
if model is not None:
|
||||||
|
new_modelpatcher = model.clone()
|
||||||
|
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
|
||||||
|
else:
|
||||||
|
k = ()
|
||||||
|
new_modelpatcher = None
|
||||||
|
|
||||||
|
if clip is not None:
|
||||||
|
new_clip = clip.clone()
|
||||||
|
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||||
|
else:
|
||||||
|
k1 = ()
|
||||||
|
new_clip = None
|
||||||
|
k = set(k)
|
||||||
|
k1 = set(k1)
|
||||||
|
for x in loaded:
|
||||||
|
if (x not in k) and (x not in k1):
|
||||||
|
print(f"NOT LOADED {x}")
|
||||||
|
return (new_modelpatcher, new_clip, hook_group)
|
||||||
|
|
||||||
|
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
||||||
|
hooks_key = 'hooks'
|
||||||
|
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
|
||||||
|
if hooks_key not in values:
|
||||||
|
return
|
||||||
|
if hooks_key not in c_dict:
|
||||||
|
hooks_value = values.get(hooks_key, None)
|
||||||
|
if hooks_value is not None:
|
||||||
|
c_dict[hooks_key] = hooks_value
|
||||||
|
return
|
||||||
|
# otherwise, need to combine with minimum duplication via cache
|
||||||
|
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
|
||||||
|
cached_hooks = cache.get(hooks_tuple, None)
|
||||||
|
if cached_hooks is None:
|
||||||
|
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
|
||||||
|
cache[hooks_tuple] = new_hooks
|
||||||
|
c_dict[hooks_key] = new_hooks
|
||||||
|
else:
|
||||||
|
c_dict[hooks_key] = cache[hooks_tuple]
|
||||||
|
|
||||||
|
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
||||||
|
c = []
|
||||||
|
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
||||||
|
for t in conditioning:
|
||||||
|
n = [t[0], t[1].copy()]
|
||||||
|
for k in values:
|
||||||
|
if append_hooks and k == 'hooks':
|
||||||
|
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
||||||
|
else:
|
||||||
|
n[1][k] = values[k]
|
||||||
|
c.append(n)
|
||||||
|
|
||||||
|
return c
|
||||||
|
|
||||||
|
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
||||||
|
if hooks is None:
|
||||||
|
return cond
|
||||||
|
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||||
|
|
||||||
|
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||||
|
if timestep_range is None:
|
||||||
|
return cond
|
||||||
|
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
|
||||||
|
"end_percent": timestep_range[1]})
|
||||||
|
|
||||||
|
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
|
||||||
|
if mask is None:
|
||||||
|
return cond
|
||||||
|
set_area_to_bounds = False
|
||||||
|
if set_cond_area != 'default':
|
||||||
|
set_area_to_bounds = True
|
||||||
|
if len(mask.shape) < 3:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
return conditioning_set_values(cond, {'mask': mask,
|
||||||
|
'set_area_to_bounds': set_area_to_bounds,
|
||||||
|
'mask_strength': strength})
|
||||||
|
|
||||||
|
def combine_conditioning(conds: list):
|
||||||
|
combined_conds = []
|
||||||
|
for cond in conds:
|
||||||
|
combined_conds.extend(cond)
|
||||||
|
return combined_conds
|
||||||
|
|
||||||
|
def combine_with_new_conds(conds: list, new_conds: list):
|
||||||
|
combined_conds = []
|
||||||
|
for c, new_c in zip(conds, new_conds):
|
||||||
|
combined_conds.append(combine_conditioning([c, new_c]))
|
||||||
|
return combined_conds
|
||||||
|
|
||||||
|
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||||
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
|
final_conds = []
|
||||||
|
for c in conds:
|
||||||
|
# first, apply lora_hook to conditioning, if provided
|
||||||
|
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
||||||
|
# next, apply mask to conditioning
|
||||||
|
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||||
|
# apply timesteps, if present
|
||||||
|
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
|
||||||
|
# finally, apply mask to conditioning and store
|
||||||
|
final_conds.append(c)
|
||||||
|
return final_conds
|
||||||
|
|
||||||
|
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||||
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
|
combined_conds = []
|
||||||
|
for c, masked_c in zip(conds, new_conds):
|
||||||
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
|
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
||||||
|
# next, apply mask to new conditioning, if provided
|
||||||
|
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||||
|
# apply timesteps, if present
|
||||||
|
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
|
||||||
|
# finally, combine with existing conditioning and store
|
||||||
|
combined_conds.append(combine_conditioning([c, masked_c]))
|
||||||
|
return combined_conds
|
||||||
|
|
||||||
|
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||||
|
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
|
combined_conds = []
|
||||||
|
for c, new_c in zip(conds, new_conds):
|
||||||
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
|
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
||||||
|
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||||
|
new_c = conditioning_set_values(new_c, {'default': True})
|
||||||
|
# apply timesteps, if present
|
||||||
|
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
|
||||||
|
# finally, combine with existing conditioning and store
|
||||||
|
combined_conds.append(combine_conditioning([c, new_c]))
|
||||||
|
return combined_conds
|
@ -15,6 +15,7 @@ from .util import (
|
|||||||
)
|
)
|
||||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||||
from comfy.ldm.util import exists
|
from comfy.ldm.util import exists
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
@ -47,6 +48,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
|||||||
elif isinstance(layer, Upsample):
|
elif isinstance(layer, Upsample):
|
||||||
x = layer(x, output_shape=output_shape)
|
x = layer(x, output_shape=output_shape)
|
||||||
else:
|
else:
|
||||||
|
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||||
|
found_patched = False
|
||||||
|
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||||
|
if isinstance(layer, class_type):
|
||||||
|
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||||
|
found_patched = True
|
||||||
|
break
|
||||||
|
if found_patched:
|
||||||
|
continue
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -819,6 +829,13 @@ class UNetModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
:param x: an [N x C x ...] Tensor of inputs.
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
|
@ -33,7 +33,7 @@ LORA_CLIP_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_lora(lora, to_load):
|
def load_lora(lora, to_load, log_missing=True):
|
||||||
patch_dict = {}
|
patch_dict = {}
|
||||||
loaded_keys = set()
|
loaded_keys = set()
|
||||||
for x in to_load:
|
for x in to_load:
|
||||||
@ -213,6 +213,7 @@ def load_lora(lora, to_load):
|
|||||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||||
loaded_keys.add(set_weight_name)
|
loaded_keys.add(set_weight_name)
|
||||||
|
|
||||||
|
if log_missing:
|
||||||
for x in lora.keys():
|
for x in lora.keys():
|
||||||
if x not in loaded_keys:
|
if x not in loaded_keys:
|
||||||
logging.warning("lora key not loaded: {}".format(x))
|
logging.warning("lora key not loaded: {}".format(x))
|
||||||
@ -429,7 +430,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
|||||||
|
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||||
for p in patches:
|
for p in patches:
|
||||||
strength = p[0]
|
strength = p[0]
|
||||||
v = p[1]
|
v = p[1]
|
||||||
@ -471,6 +472,11 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||||
elif patch_type == "set":
|
elif patch_type == "set":
|
||||||
weight.copy_(v[0])
|
weight.copy_(v[0])
|
||||||
|
elif patch_type == "model_as_lora":
|
||||||
|
target_weight: torch.Tensor = v[0]
|
||||||
|
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||||
|
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
|
||||||
|
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
|
||||||
elif patch_type == "lora": #lora/locon
|
elif patch_type == "lora": #lora/locon
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||||
|
@ -33,12 +33,16 @@ import comfy.ldm.flux.model
|
|||||||
import comfy.ldm.lightricks.model
|
import comfy.ldm.lightricks.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from . import utils
|
from . import utils
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
import math
|
import math
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
|
||||||
class ModelType(Enum):
|
class ModelType(Enum):
|
||||||
EPS = 1
|
EPS = 1
|
||||||
@ -95,6 +99,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.current_patcher: 'ModelPatcher' = None
|
||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
@ -120,6 +125,13 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.memory_usage_factor = model_config.memory_usage_factor
|
self.memory_usage_factor = model_config.memory_usage_factor
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._apply_model,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
|
||||||
|
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
sigma = t
|
sigma = t
|
||||||
xc = self.model_sampling.calculate_input(sigma, x)
|
xc = self.model_sampling.calculate_input(sigma, x)
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
|
@ -16,6 +16,8 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Optional, Callable
|
||||||
import torch
|
import torch
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
@ -28,6 +30,9 @@ import comfy.utils
|
|||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
|
import comfy.hooks
|
||||||
|
import comfy.patcher_extension
|
||||||
|
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
||||||
from comfy.comfy_types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
|
||||||
def string_to_seed(data):
|
def string_to_seed(data):
|
||||||
@ -76,6 +81,17 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
model_options["disable_cfg1_optimization"] = True
|
model_options["disable_cfg1_optimization"] = True
|
||||||
return model_options
|
return model_options
|
||||||
|
|
||||||
|
def create_model_options_clone(orig_model_options: dict):
|
||||||
|
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||||
|
|
||||||
|
def create_hook_patches_clone(orig_hook_patches):
|
||||||
|
new_hook_patches = {}
|
||||||
|
for hook_ref in orig_hook_patches:
|
||||||
|
new_hook_patches[hook_ref] = {}
|
||||||
|
for k in orig_hook_patches[hook_ref]:
|
||||||
|
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||||
|
return new_hook_patches
|
||||||
|
|
||||||
def wipe_lowvram_weight(m):
|
def wipe_lowvram_weight(m):
|
||||||
if hasattr(m, "prev_comfy_cast_weights"):
|
if hasattr(m, "prev_comfy_cast_weights"):
|
||||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||||
@ -119,6 +135,49 @@ def get_key_weight(model, key):
|
|||||||
|
|
||||||
return weight, set_func, convert_func
|
return weight, set_func, convert_func
|
||||||
|
|
||||||
|
class AutoPatcherEjector:
|
||||||
|
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
|
||||||
|
self.model = model
|
||||||
|
self.was_injected = False
|
||||||
|
self.prev_skip_injection = False
|
||||||
|
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.was_injected = False
|
||||||
|
self.prev_skip_injection = self.model.skip_injection
|
||||||
|
if self.skip_and_inject_on_exit_only:
|
||||||
|
self.model.skip_injection = True
|
||||||
|
if self.model.is_injected:
|
||||||
|
self.model.eject_model()
|
||||||
|
self.was_injected = True
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
if self.skip_and_inject_on_exit_only:
|
||||||
|
self.model.skip_injection = self.prev_skip_injection
|
||||||
|
self.model.inject_model()
|
||||||
|
if self.was_injected and not self.model.skip_injection:
|
||||||
|
self.model.inject_model()
|
||||||
|
self.model.skip_injection = self.prev_skip_injection
|
||||||
|
|
||||||
|
class MemoryCounter:
|
||||||
|
def __init__(self, initial: int, minimum=0):
|
||||||
|
self.value = initial
|
||||||
|
self.minimum = minimum
|
||||||
|
# TODO: add a safe limit besides 0
|
||||||
|
|
||||||
|
def use(self, weight: torch.Tensor):
|
||||||
|
weight_size = weight.nelement() * weight.element_size()
|
||||||
|
if self.is_useable(weight_size):
|
||||||
|
self.decrement(weight_size)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_useable(self, used: int):
|
||||||
|
return self.value - used > self.minimum
|
||||||
|
|
||||||
|
def decrement(self, used: int):
|
||||||
|
self.value -= used
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -141,6 +200,24 @@ class ModelPatcher:
|
|||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
self.parent = None
|
self.parent = None
|
||||||
|
|
||||||
|
self.attachments: dict[str] = {}
|
||||||
|
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||||
|
self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks()
|
||||||
|
self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers()
|
||||||
|
|
||||||
|
self.is_injected = False
|
||||||
|
self.skip_injection = False
|
||||||
|
self.injections: dict[str, list[PatcherInjection]] = {}
|
||||||
|
|
||||||
|
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
||||||
|
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
|
||||||
|
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
||||||
|
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
||||||
|
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
||||||
|
self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time
|
||||||
|
self.is_clip = False
|
||||||
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
@ -177,6 +254,47 @@ class ModelPatcher:
|
|||||||
n.backup = self.backup
|
n.backup = self.backup
|
||||||
n.object_patches_backup = self.object_patches_backup
|
n.object_patches_backup = self.object_patches_backup
|
||||||
n.parent = self
|
n.parent = self
|
||||||
|
|
||||||
|
# attachments
|
||||||
|
n.attachments = {}
|
||||||
|
for k in self.attachments:
|
||||||
|
if hasattr(self.attachments[k], "on_model_patcher_clone"):
|
||||||
|
n.attachments[k] = self.attachments[k].on_model_patcher_clone()
|
||||||
|
else:
|
||||||
|
n.attachments[k] = self.attachments[k]
|
||||||
|
# additional models
|
||||||
|
for k, c in self.additional_models.items():
|
||||||
|
n.additional_models[k] = [x.clone() for x in c]
|
||||||
|
# callbacks
|
||||||
|
for k, c in self.callbacks.items():
|
||||||
|
n.callbacks[k] = {}
|
||||||
|
for k1, c1 in c.items():
|
||||||
|
n.callbacks[k][k1] = c1.copy()
|
||||||
|
# sample wrappers
|
||||||
|
for k, w in self.wrappers.items():
|
||||||
|
n.wrappers[k] = {}
|
||||||
|
for k1, w1 in w.items():
|
||||||
|
n.wrappers[k][k1] = w1.copy()
|
||||||
|
# injection
|
||||||
|
n.is_injected = self.is_injected
|
||||||
|
n.skip_injection = self.skip_injection
|
||||||
|
for k, i in self.injections.items():
|
||||||
|
n.injections[k] = i.copy()
|
||||||
|
# hooks
|
||||||
|
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
||||||
|
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
|
||||||
|
for group in self.cached_hook_patches:
|
||||||
|
n.cached_hook_patches[group] = {}
|
||||||
|
for k in self.cached_hook_patches[group]:
|
||||||
|
n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k]
|
||||||
|
n.hook_backup = self.hook_backup
|
||||||
|
n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
|
||||||
|
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
|
||||||
|
n.is_clip = self.is_clip
|
||||||
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
|
callback(self, n)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def is_clone(self, other):
|
def is_clone(self, other):
|
||||||
@ -184,10 +302,29 @@ class ModelPatcher:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def clone_has_same_weights(self, clone):
|
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||||
if not self.is_clone(clone):
|
if not self.is_clone(clone):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if self.current_hooks != clone.current_hooks:
|
||||||
|
return False
|
||||||
|
if self.forced_hooks != clone.forced_hooks:
|
||||||
|
return False
|
||||||
|
if self.hook_patches.keys() != clone.hook_patches.keys():
|
||||||
|
return False
|
||||||
|
if self.attachments.keys() != clone.attachments.keys():
|
||||||
|
return False
|
||||||
|
if self.additional_models.keys() != clone.additional_models.keys():
|
||||||
|
return False
|
||||||
|
for key in self.callbacks:
|
||||||
|
if len(self.callbacks[key]) != len(clone.callbacks[key]):
|
||||||
|
return False
|
||||||
|
for key in self.wrappers:
|
||||||
|
if len(self.wrappers[key]) != len(clone.wrappers[key]):
|
||||||
|
return False
|
||||||
|
if self.injections.keys() != clone.injections.keys():
|
||||||
|
return False
|
||||||
|
|
||||||
if len(self.patches) == 0 and len(clone.patches) == 0:
|
if len(self.patches) == 0 and len(clone.patches) == 0:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -256,6 +393,12 @@ class ModelPatcher:
|
|||||||
def set_model_output_block_patch(self, patch):
|
def set_model_output_block_patch(self, patch):
|
||||||
self.set_model_patch(patch, "output_block_patch")
|
self.set_model_patch(patch, "output_block_patch")
|
||||||
|
|
||||||
|
def set_model_emb_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "emb_patch")
|
||||||
|
|
||||||
|
def set_model_forward_timestep_embed_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||||
|
|
||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
@ -294,6 +437,7 @@ class ModelPatcher:
|
|||||||
return self.model.get_dtype()
|
return self.model.get_dtype()
|
||||||
|
|
||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
|
with self.use_ejected():
|
||||||
p = set()
|
p = set()
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
for k in patches:
|
for k in patches:
|
||||||
@ -324,9 +468,12 @@ class ModelPatcher:
|
|||||||
if not k.startswith(filter_prefix):
|
if not k.startswith(filter_prefix):
|
||||||
continue
|
continue
|
||||||
bk = self.backup.get(k, None)
|
bk = self.backup.get(k, None)
|
||||||
|
hbk = self.hook_backup.get(k, None)
|
||||||
weight, set_func, convert_func = get_key_weight(self.model, k)
|
weight, set_func, convert_func = get_key_weight(self.model, k)
|
||||||
if bk is not None:
|
if bk is not None:
|
||||||
weight = bk.weight
|
weight = bk.weight
|
||||||
|
if hbk is not None:
|
||||||
|
weight = hbk[0]
|
||||||
if convert_func is None:
|
if convert_func is None:
|
||||||
convert_func = lambda a, **kwargs: a
|
convert_func = lambda a, **kwargs: a
|
||||||
|
|
||||||
@ -337,6 +484,7 @@ class ModelPatcher:
|
|||||||
return p
|
return p
|
||||||
|
|
||||||
def model_state_dict(self, filter_prefix=None):
|
def model_state_dict(self, filter_prefix=None):
|
||||||
|
with self.use_ejected():
|
||||||
sd = self.model.state_dict()
|
sd = self.model.state_dict()
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
if filter_prefix is not None:
|
if filter_prefix is not None:
|
||||||
@ -388,6 +536,8 @@ class ModelPatcher:
|
|||||||
return loading
|
return loading
|
||||||
|
|
||||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
|
with self.use_ejected():
|
||||||
|
self.unpatch_hooks()
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
@ -471,7 +621,13 @@ class ModelPatcher:
|
|||||||
self.model.model_loaded_weight_memory = mem_counter
|
self.model.model_loaded_weight_memory = mem_counter
|
||||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||||
|
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||||
|
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||||
|
|
||||||
|
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||||
|
|
||||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||||
|
with self.use_ejected():
|
||||||
for k in self.object_patches:
|
for k in self.object_patches:
|
||||||
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
||||||
if k not in self.object_patches_backup:
|
if k not in self.object_patches_backup:
|
||||||
@ -484,10 +640,13 @@ class ModelPatcher:
|
|||||||
|
|
||||||
if load_weights:
|
if load_weights:
|
||||||
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
||||||
|
self.inject_model()
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
|
self.eject_model()
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
|
self.unpatch_hooks()
|
||||||
if self.model.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
@ -523,6 +682,7 @@ class ModelPatcher:
|
|||||||
self.object_patches_backup.clear()
|
self.object_patches_backup.clear()
|
||||||
|
|
||||||
def partially_unload(self, device_to, memory_to_free=0):
|
def partially_unload(self, device_to, memory_to_free=0):
|
||||||
|
with self.use_ejected():
|
||||||
memory_freed = 0
|
memory_freed = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
unload_list = self._load_list()
|
unload_list = self._load_list()
|
||||||
@ -576,6 +736,7 @@ class ModelPatcher:
|
|||||||
return memory_freed
|
return memory_freed
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
|
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||||
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
|
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
|
||||||
# TODO: force_patch_weights should not unload + reload full model
|
# TODO: force_patch_weights should not unload + reload full model
|
||||||
used = self.model.model_loaded_weight_memory
|
used = self.model.model_loaded_weight_memory
|
||||||
@ -586,6 +747,7 @@ class ModelPatcher:
|
|||||||
self.patch_model(load_weights=False)
|
self.patch_model(load_weights=False)
|
||||||
full_load = False
|
full_load = False
|
||||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||||
|
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||||
return 0
|
return 0
|
||||||
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
full_load = True
|
full_load = True
|
||||||
@ -599,9 +761,12 @@ class ModelPatcher:
|
|||||||
return self.model.model_loaded_weight_memory - current_used
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
def detach(self, unpatch_all=True):
|
def detach(self, unpatch_all=True):
|
||||||
|
self.eject_model()
|
||||||
self.model_patches_to(self.offload_device)
|
self.model_patches_to(self.offload_device)
|
||||||
if unpatch_all:
|
if unpatch_all:
|
||||||
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
|
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
|
||||||
|
callback(self, unpatch_all)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def current_loaded_device(self):
|
def current_loaded_device(self):
|
||||||
@ -611,6 +776,345 @@ class ModelPatcher:
|
|||||||
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
||||||
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
self.clean_hooks()
|
||||||
|
if hasattr(self.model, "current_patcher"):
|
||||||
|
self.model.current_patcher = None
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
|
||||||
|
callback(self)
|
||||||
|
|
||||||
|
def add_callback(self, call_type: str, callback: Callable):
|
||||||
|
self.add_callback_with_key(call_type, None, callback)
|
||||||
|
|
||||||
|
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
||||||
|
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||||
|
c.append(callback)
|
||||||
|
|
||||||
|
def remove_callbacks_with_key(self, call_type: str, key: str):
|
||||||
|
c = self.callbacks.get(call_type, {})
|
||||||
|
if key in c:
|
||||||
|
c.pop(key)
|
||||||
|
|
||||||
|
def get_callbacks(self, call_type: str, key: str):
|
||||||
|
return self.callbacks.get(call_type, {}).get(key, [])
|
||||||
|
|
||||||
|
def get_all_callbacks(self, call_type: str):
|
||||||
|
c_list = []
|
||||||
|
for c in self.callbacks.get(call_type, {}).values():
|
||||||
|
c_list.extend(c)
|
||||||
|
return c_list
|
||||||
|
|
||||||
|
def add_wrapper(self, wrapper_type: str, wrapper: Callable):
|
||||||
|
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
||||||
|
|
||||||
|
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
||||||
|
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||||
|
w.append(wrapper)
|
||||||
|
|
||||||
|
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
||||||
|
w = self.wrappers.get(wrapper_type, {})
|
||||||
|
if key in w:
|
||||||
|
w.pop(key)
|
||||||
|
|
||||||
|
def get_wrappers(self, wrapper_type: str, key: str):
|
||||||
|
return self.wrappers.get(wrapper_type, {}).get(key, [])
|
||||||
|
|
||||||
|
def get_all_wrappers(self, wrapper_type: str):
|
||||||
|
w_list = []
|
||||||
|
for w in self.wrappers.get(wrapper_type, {}).values():
|
||||||
|
w_list.extend(w)
|
||||||
|
return w_list
|
||||||
|
|
||||||
|
def set_attachments(self, key: str, attachment):
|
||||||
|
self.attachments[key] = attachment
|
||||||
|
|
||||||
|
def remove_attachments(self, key: str):
|
||||||
|
if key in self.attachments:
|
||||||
|
self.attachments.pop(key)
|
||||||
|
|
||||||
|
def get_attachment(self, key: str):
|
||||||
|
return self.attachments.get(key, None)
|
||||||
|
|
||||||
|
def set_injections(self, key: str, injections: list[PatcherInjection]):
|
||||||
|
self.injections[key] = injections
|
||||||
|
|
||||||
|
def remove_injections(self, key: str):
|
||||||
|
if key in self.injections:
|
||||||
|
self.injections.pop(key)
|
||||||
|
|
||||||
|
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
||||||
|
self.additional_models[key] = models
|
||||||
|
|
||||||
|
def remove_additional_models(self, key: str):
|
||||||
|
if key in self.additional_models:
|
||||||
|
self.additional_models.pop(key)
|
||||||
|
|
||||||
|
def get_additional_models_with_key(self, key: str):
|
||||||
|
return self.additional_models.get(key, [])
|
||||||
|
|
||||||
|
def get_additional_models(self):
|
||||||
|
all_models = []
|
||||||
|
for models in self.additional_models.values():
|
||||||
|
all_models.extend(models)
|
||||||
|
return all_models
|
||||||
|
|
||||||
|
def get_nested_additional_models(self):
|
||||||
|
def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]):
|
||||||
|
'''Make sure circular references do not cause infinite recursion.'''
|
||||||
|
next_models = []
|
||||||
|
for model in prev_models:
|
||||||
|
candidates = model.get_additional_models()
|
||||||
|
for c in candidates:
|
||||||
|
if c not in cache_set:
|
||||||
|
next_models.append(c)
|
||||||
|
cache_set.add(c)
|
||||||
|
if len(next_models) == 0:
|
||||||
|
return prev_models
|
||||||
|
return prev_models + _evaluate_sub_additional_models(next_models, cache_set)
|
||||||
|
|
||||||
|
all_models = self.get_additional_models()
|
||||||
|
models_set = set(all_models)
|
||||||
|
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
|
||||||
|
return real_all_models
|
||||||
|
|
||||||
|
def use_ejected(self, skip_and_inject_on_exit_only=False):
|
||||||
|
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)
|
||||||
|
|
||||||
|
def inject_model(self):
|
||||||
|
if self.is_injected or self.skip_injection:
|
||||||
|
return
|
||||||
|
for injections in self.injections.values():
|
||||||
|
for inj in injections:
|
||||||
|
inj.inject(self)
|
||||||
|
self.is_injected = True
|
||||||
|
if self.is_injected:
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL):
|
||||||
|
callback(self)
|
||||||
|
|
||||||
|
def eject_model(self):
|
||||||
|
if not self.is_injected:
|
||||||
|
return
|
||||||
|
for injections in self.injections.values():
|
||||||
|
for inj in injections:
|
||||||
|
inj.eject(self)
|
||||||
|
self.is_injected = False
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL):
|
||||||
|
callback(self)
|
||||||
|
|
||||||
|
def pre_run(self):
|
||||||
|
if hasattr(self.model, "current_patcher"):
|
||||||
|
self.model.current_patcher = self
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||||
|
callback(self)
|
||||||
|
|
||||||
|
def prepare_state(self, timestep):
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||||
|
callback(self, timestep)
|
||||||
|
|
||||||
|
def restore_hook_patches(self):
|
||||||
|
if len(self.hook_patches_backup) > 0:
|
||||||
|
self.hook_patches = self.hook_patches_backup
|
||||||
|
self.hook_patches_backup = {}
|
||||||
|
|
||||||
|
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
||||||
|
self.hook_mode = hook_mode
|
||||||
|
|
||||||
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
|
||||||
|
curr_t = t[0]
|
||||||
|
reset_current_hooks = False
|
||||||
|
for hook in hook_group.hooks:
|
||||||
|
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
|
||||||
|
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||||
|
# this will cause the weights to be recalculated when sampling
|
||||||
|
if changed:
|
||||||
|
# reset current_hooks if contains hook that changed
|
||||||
|
if self.current_hooks is not None:
|
||||||
|
for current_hook in self.current_hooks.hooks:
|
||||||
|
if current_hook == hook:
|
||||||
|
reset_current_hooks = True
|
||||||
|
break
|
||||||
|
for cached_group in list(self.cached_hook_patches.keys()):
|
||||||
|
if cached_group.contains(hook):
|
||||||
|
self.cached_hook_patches.pop(cached_group)
|
||||||
|
if reset_current_hooks:
|
||||||
|
self.patch_hooks(None)
|
||||||
|
|
||||||
|
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
|
||||||
|
self.restore_hook_patches()
|
||||||
|
registered_hooks: list[comfy.hooks.Hook] = []
|
||||||
|
# handle WrapperHooks, if model_options provided
|
||||||
|
if model_options is not None:
|
||||||
|
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
|
||||||
|
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
||||||
|
# handle WeightHooks
|
||||||
|
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||||
|
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
|
||||||
|
if hook.hook_ref not in self.hook_patches:
|
||||||
|
weight_hooks_to_register.append(hook)
|
||||||
|
if len(weight_hooks_to_register) > 0:
|
||||||
|
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
||||||
|
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
||||||
|
for hook in weight_hooks_to_register:
|
||||||
|
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
||||||
|
callback(self, hooks_dict, target)
|
||||||
|
|
||||||
|
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
|
with self.use_ejected():
|
||||||
|
# NOTE: this mirrors behavior of add_patches func
|
||||||
|
current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {})
|
||||||
|
p = set()
|
||||||
|
model_sd = self.model.state_dict()
|
||||||
|
for k in patches:
|
||||||
|
offset = None
|
||||||
|
function = None
|
||||||
|
if isinstance(k, str):
|
||||||
|
key = k
|
||||||
|
else:
|
||||||
|
offset = k[1]
|
||||||
|
key = k[0]
|
||||||
|
if len(k) > 2:
|
||||||
|
function = k[2]
|
||||||
|
|
||||||
|
if key in model_sd:
|
||||||
|
p.add(k)
|
||||||
|
current_patches: list[tuple] = current_hook_patches.get(key, [])
|
||||||
|
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
|
||||||
|
current_hook_patches[key] = current_patches
|
||||||
|
self.hook_patches[hook.hook_ref] = current_hook_patches
|
||||||
|
# since should care about these patches too to determine if same model, reroll patches_uuid
|
||||||
|
self.patches_uuid = uuid.uuid4()
|
||||||
|
return list(p)
|
||||||
|
|
||||||
|
def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
|
||||||
|
# combined_patches will contain weights of all relevant hooks, per key
|
||||||
|
combined_patches = {}
|
||||||
|
if hooks is not None:
|
||||||
|
for hook in hooks.hooks:
|
||||||
|
hook_patches: dict = self.hook_patches.get(hook.hook_ref, {})
|
||||||
|
for key in hook_patches.keys():
|
||||||
|
current_patches: list[tuple] = combined_patches.get(key, [])
|
||||||
|
if math.isclose(hook.strength, 1.0):
|
||||||
|
current_patches.extend(hook_patches[key])
|
||||||
|
else:
|
||||||
|
# patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model)
|
||||||
|
for patch in hook_patches[key]:
|
||||||
|
new_patch = list(patch)
|
||||||
|
new_patch[0] *= hook.strength
|
||||||
|
current_patches.append(tuple(new_patch))
|
||||||
|
combined_patches[key] = current_patches
|
||||||
|
return combined_patches
|
||||||
|
|
||||||
|
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
||||||
|
# TODO: return transformer_options dict with any additions from hooks
|
||||||
|
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
||||||
|
return {}
|
||||||
|
self.patch_hooks(hooks=hooks)
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
||||||
|
callback(self, hooks)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||||
|
with self.use_ejected():
|
||||||
|
self.unpatch_hooks()
|
||||||
|
if hooks is not None:
|
||||||
|
model_sd_keys = list(self.model_state_dict().keys())
|
||||||
|
memory_counter = None
|
||||||
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||||
|
# TODO: minimum_counter should have a minimum that conforms to loaded model requirements
|
||||||
|
memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
|
||||||
|
minimum=comfy.model_management.minimum_inference_memory()*2)
|
||||||
|
# if have cached weights for hooks, use it
|
||||||
|
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||||
|
if cached_weights is not None:
|
||||||
|
for key in cached_weights:
|
||||||
|
if key not in model_sd_keys:
|
||||||
|
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
|
||||||
|
continue
|
||||||
|
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||||
|
else:
|
||||||
|
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
||||||
|
original_weights = None
|
||||||
|
if len(relevant_patches) > 0:
|
||||||
|
original_weights = self.get_key_patches()
|
||||||
|
for key in relevant_patches:
|
||||||
|
if key not in model_sd_keys:
|
||||||
|
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
|
||||||
|
continue
|
||||||
|
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||||
|
memory_counter=memory_counter)
|
||||||
|
self.current_hooks = hooks
|
||||||
|
|
||||||
|
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||||
|
if key not in self.hook_backup:
|
||||||
|
weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
|
||||||
|
target_device = self.offload_device
|
||||||
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||||
|
used = memory_counter.use(weight)
|
||||||
|
if used:
|
||||||
|
target_device = weight.device
|
||||||
|
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
|
||||||
|
comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
|
||||||
|
|
||||||
|
def clear_cached_hook_weights(self):
|
||||||
|
self.cached_hook_patches.clear()
|
||||||
|
self.patch_hooks(None)
|
||||||
|
|
||||||
|
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||||
|
if key not in combined_patches:
|
||||||
|
return
|
||||||
|
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
weight: torch.Tensor
|
||||||
|
if key not in self.hook_backup:
|
||||||
|
target_device = self.offload_device
|
||||||
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||||
|
used = memory_counter.use(weight)
|
||||||
|
if used:
|
||||||
|
target_device = weight.device
|
||||||
|
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
|
||||||
|
# TODO: properly handle LowVramPatch, if it ends up an issue
|
||||||
|
temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
|
||||||
|
if convert_func is not None:
|
||||||
|
temp_weight = convert_func(temp_weight, inplace=True)
|
||||||
|
|
||||||
|
out_weight = comfy.lora.calculate_weight(combined_patches[key],
|
||||||
|
temp_weight,
|
||||||
|
key, original_weights=original_weights)
|
||||||
|
del original_weights[key]
|
||||||
|
if set_func is None:
|
||||||
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||||
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
|
else:
|
||||||
|
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
|
||||||
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||||
|
# TODO: disable caching if not enough system RAM to do so
|
||||||
|
target_device = self.offload_device
|
||||||
|
used = memory_counter.use(weight)
|
||||||
|
if used:
|
||||||
|
target_device = weight.device
|
||||||
|
self.cached_hook_patches.setdefault(hooks, {})
|
||||||
|
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
|
||||||
|
del temp_weight
|
||||||
|
del out_weight
|
||||||
|
del weight
|
||||||
|
|
||||||
|
def unpatch_hooks(self) -> None:
|
||||||
|
with self.use_ejected():
|
||||||
|
if len(self.hook_backup) == 0:
|
||||||
|
self.current_hooks = None
|
||||||
|
return
|
||||||
|
keys = list(self.hook_backup.keys())
|
||||||
|
for k in keys:
|
||||||
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
|
||||||
|
self.hook_backup.clear()
|
||||||
|
self.current_hooks = None
|
||||||
|
|
||||||
|
def clean_hooks(self):
|
||||||
|
self.unpatch_hooks()
|
||||||
|
self.clear_cached_hook_weights()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.detach(unpatch_all=False)
|
self.detach(unpatch_all=False)
|
||||||
|
|
||||||
|
156
comfy/patcher_extension.py
Normal file
156
comfy/patcher_extension.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
class CallbacksMP:
|
||||||
|
ON_CLONE = "on_clone"
|
||||||
|
ON_LOAD = "on_load_after"
|
||||||
|
ON_DETACH = "on_detach_after"
|
||||||
|
ON_CLEANUP = "on_cleanup"
|
||||||
|
ON_PRE_RUN = "on_pre_run"
|
||||||
|
ON_PREPARE_STATE = "on_prepare_state"
|
||||||
|
ON_APPLY_HOOKS = "on_apply_hooks"
|
||||||
|
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
|
||||||
|
ON_INJECT_MODEL = "on_inject_model"
|
||||||
|
ON_EJECT_MODEL = "on_eject_model"
|
||||||
|
|
||||||
|
# callbacks dict is in the format:
|
||||||
|
# {"call_type": {"key": [Callable1, Callable2, ...]} }
|
||||||
|
@classmethod
|
||||||
|
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||||
|
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
|
||||||
|
|
||||||
|
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||||
|
if is_model_options:
|
||||||
|
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||||
|
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
|
||||||
|
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||||
|
c.append(callback)
|
||||||
|
|
||||||
|
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||||
|
if is_model_options:
|
||||||
|
transformer_options = transformer_options.get("transformer_options", {})
|
||||||
|
c_list = []
|
||||||
|
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||||
|
c_list.extend(callbacks.get(call_type, {}).get(key, []))
|
||||||
|
return c_list
|
||||||
|
|
||||||
|
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
|
||||||
|
if is_model_options:
|
||||||
|
transformer_options = transformer_options.get("transformer_options", {})
|
||||||
|
c_list = []
|
||||||
|
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
|
||||||
|
for c in callbacks.get(call_type, {}).values():
|
||||||
|
c_list.extend(c)
|
||||||
|
return c_list
|
||||||
|
|
||||||
|
class WrappersMP:
|
||||||
|
OUTER_SAMPLE = "outer_sample"
|
||||||
|
SAMPLER_SAMPLE = "sampler_sample"
|
||||||
|
CALC_COND_BATCH = "calc_cond_batch"
|
||||||
|
APPLY_MODEL = "apply_model"
|
||||||
|
DIFFUSION_MODEL = "diffusion_model"
|
||||||
|
|
||||||
|
# wrappers dict is in the format:
|
||||||
|
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
|
||||||
|
@classmethod
|
||||||
|
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||||
|
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
|
||||||
|
|
||||||
|
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||||
|
if is_model_options:
|
||||||
|
transformer_options = transformer_options.setdefault("transformer_options", {})
|
||||||
|
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
|
||||||
|
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||||
|
w.append(wrapper)
|
||||||
|
|
||||||
|
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
|
||||||
|
if is_model_options:
|
||||||
|
transformer_options = transformer_options.get("transformer_options", {})
|
||||||
|
w_list = []
|
||||||
|
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||||
|
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
|
||||||
|
return w_list
|
||||||
|
|
||||||
|
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
|
||||||
|
if is_model_options:
|
||||||
|
transformer_options = transformer_options.get("transformer_options", {})
|
||||||
|
w_list = []
|
||||||
|
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
|
||||||
|
for w in wrappers.get(wrapper_type, {}).values():
|
||||||
|
w_list.extend(w)
|
||||||
|
return w_list
|
||||||
|
|
||||||
|
class WrapperExecutor:
|
||||||
|
"""Handles call stack of wrappers around a function in an ordered manner."""
|
||||||
|
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
|
||||||
|
# NOTE: class_obj exists so that wrappers surrounding a class method can access
|
||||||
|
# the class instance at runtime via executor.class_obj
|
||||||
|
self.original = original
|
||||||
|
self.class_obj = class_obj
|
||||||
|
self.wrappers = wrappers.copy()
|
||||||
|
self.idx = idx
|
||||||
|
self.is_last = idx == len(wrappers)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
"""Calls the next wrapper or original function, whichever is appropriate."""
|
||||||
|
new_executor = self._create_next_executor()
|
||||||
|
return new_executor.execute(*args, **kwargs)
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
||||||
|
args = list(args)
|
||||||
|
kwargs = dict(kwargs)
|
||||||
|
if self.is_last:
|
||||||
|
return self.original(*args, **kwargs)
|
||||||
|
return self.wrappers[self.idx](self, *args, **kwargs)
|
||||||
|
|
||||||
|
def _create_next_executor(self) -> 'WrapperExecutor':
|
||||||
|
new_idx = self.idx + 1
|
||||||
|
if new_idx > len(self.wrappers):
|
||||||
|
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
|
||||||
|
if self.class_obj is None:
|
||||||
|
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
|
||||||
|
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
||||||
|
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
||||||
|
return cls(original, class_obj, wrappers, idx=idx)
|
||||||
|
|
||||||
|
class PatcherInjection:
|
||||||
|
def __init__(self, inject: Callable, eject: Callable):
|
||||||
|
self.inject = inject
|
||||||
|
self.eject = eject
|
||||||
|
|
||||||
|
def copy_nested_dicts(input_dict: dict):
|
||||||
|
new_dict = input_dict.copy()
|
||||||
|
for key, value in input_dict.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
new_dict[key] = copy_nested_dicts(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
new_dict[key] = value.copy()
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
|
||||||
|
if copy_dict1:
|
||||||
|
merged_dict = copy_nested_dicts(dict1)
|
||||||
|
else:
|
||||||
|
merged_dict = dict1
|
||||||
|
for key, value in dict2.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
curr_value = merged_dict.setdefault(key, {})
|
||||||
|
merged_dict[key] = merge_nested_dicts(value, curr_value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
merged_dict.setdefault(key, []).extend(value)
|
||||||
|
else:
|
||||||
|
merged_dict[key] = value
|
||||||
|
return merged_dict
|
@ -1,7 +1,16 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import uuid
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.hooks
|
||||||
|
import comfy.patcher_extension
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
from comfy.model_base import BaseModel
|
||||||
|
from comfy.controlnet import ControlBase
|
||||||
|
|
||||||
def prepare_mask(noise_mask, shape, device):
|
def prepare_mask(noise_mask, shape, device):
|
||||||
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
||||||
@ -10,9 +19,43 @@ def get_models_from_cond(cond, model_type):
|
|||||||
models = []
|
models = []
|
||||||
for c in cond:
|
for c in cond:
|
||||||
if model_type in c:
|
if model_type in c:
|
||||||
|
if isinstance(c[model_type], list):
|
||||||
|
models += c[model_type]
|
||||||
|
else:
|
||||||
models += [c[model_type]]
|
models += [c[model_type]]
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
||||||
|
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||||
|
cnets: list[ControlBase] = []
|
||||||
|
for c in cond:
|
||||||
|
if 'hooks' in c:
|
||||||
|
for hook in c['hooks'].hooks:
|
||||||
|
hook: comfy.hooks.Hook
|
||||||
|
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||||
|
with_type[hook] = None
|
||||||
|
if 'control' in c:
|
||||||
|
cnets.append(c['control'])
|
||||||
|
|
||||||
|
def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
|
||||||
|
if cnet.extra_hooks is not None:
|
||||||
|
_list.append(cnet.extra_hooks)
|
||||||
|
if cnet.previous_controlnet is None:
|
||||||
|
return _list
|
||||||
|
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
|
||||||
|
|
||||||
|
hooks_list = []
|
||||||
|
cnets = set(cnets)
|
||||||
|
for base_cnet in cnets:
|
||||||
|
get_extra_hooks_from_cnet(base_cnet, hooks_list)
|
||||||
|
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||||
|
if extra_hooks is not None:
|
||||||
|
for hook in extra_hooks.hooks:
|
||||||
|
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||||
|
with_type[hook] = None
|
||||||
|
|
||||||
|
return hooks_dict
|
||||||
|
|
||||||
def convert_cond(cond):
|
def convert_cond(cond):
|
||||||
out = []
|
out = []
|
||||||
for c in cond:
|
for c in cond:
|
||||||
@ -22,17 +65,22 @@ def convert_cond(cond):
|
|||||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||||
temp["cross_attn"] = c[0]
|
temp["cross_attn"] = c[0]
|
||||||
temp["model_conds"] = model_conds
|
temp["model_conds"] = model_conds
|
||||||
|
temp["uuid"] = uuid.uuid4()
|
||||||
out.append(temp)
|
out.append(temp)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_additional_models(conds, dtype):
|
def get_additional_models(conds, dtype):
|
||||||
"""loads additional models in conditioning"""
|
"""loads additional models in conditioning"""
|
||||||
cnets = []
|
cnets: list[ControlBase] = []
|
||||||
gligen = []
|
gligen = []
|
||||||
|
add_models = []
|
||||||
|
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
||||||
|
|
||||||
for k in conds:
|
for k in conds:
|
||||||
cnets += get_models_from_cond(conds[k], "control")
|
cnets += get_models_from_cond(conds[k], "control")
|
||||||
gligen += get_models_from_cond(conds[k], "gligen")
|
gligen += get_models_from_cond(conds[k], "gligen")
|
||||||
|
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||||
|
get_hooks_from_cond(conds[k], hooks)
|
||||||
|
|
||||||
control_nets = set(cnets)
|
control_nets = set(cnets)
|
||||||
|
|
||||||
@ -43,7 +91,9 @@ def get_additional_models(conds, dtype):
|
|||||||
inference_memory += m.inference_memory_requirements(dtype)
|
inference_memory += m.inference_memory_requirements(dtype)
|
||||||
|
|
||||||
gligen = [x[1] for x in gligen]
|
gligen = [x[1] for x in gligen]
|
||||||
models = control_models + gligen
|
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
||||||
|
models = control_models + gligen + add_models + hook_models
|
||||||
|
|
||||||
return models, inference_memory
|
return models, inference_memory
|
||||||
|
|
||||||
def cleanup_additional_models(models):
|
def cleanup_additional_models(models):
|
||||||
@ -53,10 +103,11 @@ def cleanup_additional_models(models):
|
|||||||
m.cleanup()
|
m.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def prepare_sampling(model, noise_shape, conds):
|
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
||||||
device = model.load_device
|
device = model.load_device
|
||||||
real_model = None
|
real_model: 'BaseModel' = None
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
|
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
|
||||||
@ -72,3 +123,14 @@ def cleanup_models(conds, models):
|
|||||||
control_cleanup += get_models_from_cond(conds[k], "control")
|
control_cleanup += get_models_from_cond(conds[k], "control")
|
||||||
|
|
||||||
cleanup_additional_models(set(control_cleanup))
|
cleanup_additional_models(set(control_cleanup))
|
||||||
|
|
||||||
|
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||||
|
# check for hooks in conds - if not registered, see if can be applied
|
||||||
|
hooks = {}
|
||||||
|
for k in conds:
|
||||||
|
get_hooks_from_cond(conds[k], hooks)
|
||||||
|
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||||
|
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||||
|
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||||
|
# register hooks on model/model_options
|
||||||
|
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
||||||
|
@ -1,11 +1,21 @@
|
|||||||
|
from __future__ import annotations
|
||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
from comfy.model_base import BaseModel
|
||||||
|
from comfy.controlnet import ControlBase
|
||||||
import torch
|
import torch
|
||||||
import collections
|
import collections
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
|
import comfy.samplers
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.patcher_extension
|
||||||
|
import comfy.hooks
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||||||
for c in model_conds:
|
for c in model_conds:
|
||||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||||
|
|
||||||
|
hooks = conds.get('hooks', None)
|
||||||
control = conds.get('control', None)
|
control = conds.get('control', None)
|
||||||
|
|
||||||
patches = None
|
patches = None
|
||||||
@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||||||
|
|
||||||
patches['middle_patch'] = [gligen_patch]
|
patches['middle_patch'] = [gligen_patch]
|
||||||
|
|
||||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
|
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
|
||||||
return cond_obj(input_x, mult, conditioning, area, control, patches)
|
return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
|
||||||
|
|
||||||
def cond_equal_size(c1, c2):
|
def cond_equal_size(c1, c2):
|
||||||
if c1 is c2:
|
if c1 is c2:
|
||||||
@ -138,24 +149,92 @@ def cond_cat(c_list):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
|
||||||
|
# need to figure out remaining unmasked area for conds
|
||||||
|
default_mults = []
|
||||||
|
for _ in default_conds:
|
||||||
|
default_mults.append(torch.ones_like(x_in))
|
||||||
|
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
|
||||||
|
for lora_hooks, to_run in hooked_to_run.items():
|
||||||
|
for cond_obj, i in to_run:
|
||||||
|
# if no default_cond for cond_type, do nothing
|
||||||
|
if len(default_conds[i]) == 0:
|
||||||
|
continue
|
||||||
|
area: list[int] = cond_obj.area
|
||||||
|
if area is not None:
|
||||||
|
curr_default_mult: torch.Tensor = default_mults[i]
|
||||||
|
dims = len(area) // 2
|
||||||
|
for i in range(dims):
|
||||||
|
curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i])
|
||||||
|
curr_default_mult -= cond_obj.mult
|
||||||
|
else:
|
||||||
|
default_mults[i] -= cond_obj.mult
|
||||||
|
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
|
||||||
|
for i, mult in enumerate(default_mults):
|
||||||
|
# if no default_cond for cond type, do nothing
|
||||||
|
if len(default_conds[i]) == 0:
|
||||||
|
continue
|
||||||
|
torch.nn.functional.relu(mult, inplace=True)
|
||||||
|
# if mult is all zeros, then don't add default_cond
|
||||||
|
if torch.max(mult) == 0.0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cond = default_conds[i]
|
||||||
|
for x in cond:
|
||||||
|
# do get_area_and_mult to get all the expected values
|
||||||
|
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||||
|
if p is None:
|
||||||
|
continue
|
||||||
|
# replace p's mult with calculated mult
|
||||||
|
p = p._replace(mult=mult)
|
||||||
|
if p.hooks is not None:
|
||||||
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||||
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
|
|
||||||
|
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||||
|
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||||
|
_calc_cond_batch,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||||
|
)
|
||||||
|
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
|
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||||
out_conds = []
|
out_conds = []
|
||||||
out_counts = []
|
out_counts = []
|
||||||
to_run = []
|
# separate conds by matching hooks
|
||||||
|
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||||
|
default_conds = []
|
||||||
|
has_default_conds = False
|
||||||
|
|
||||||
for i in range(len(conds)):
|
for i in range(len(conds)):
|
||||||
out_conds.append(torch.zeros_like(x_in))
|
out_conds.append(torch.zeros_like(x_in))
|
||||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||||
|
|
||||||
cond = conds[i]
|
cond = conds[i]
|
||||||
|
default_c = []
|
||||||
if cond is not None:
|
if cond is not None:
|
||||||
for x in cond:
|
for x in cond:
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
if 'default' in x:
|
||||||
|
default_c.append(x)
|
||||||
|
has_default_conds = True
|
||||||
|
continue
|
||||||
|
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
if p.hooks is not None:
|
||||||
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||||
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
|
default_conds.append(default_c)
|
||||||
|
|
||||||
to_run += [(p, i)]
|
if has_default_conds:
|
||||||
|
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
|
||||||
|
|
||||||
|
model.current_patcher.prepare_state(timestep)
|
||||||
|
|
||||||
|
# run every hooked_to_run separately
|
||||||
|
for hooks, to_run in hooked_to_run.items():
|
||||||
while len(to_run) > 0:
|
while len(to_run) > 0:
|
||||||
first = to_run[0]
|
first = to_run[0]
|
||||||
first_shape = first[0][0].shape
|
first_shape = first[0][0].shape
|
||||||
@ -179,6 +258,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|||||||
mult = []
|
mult = []
|
||||||
c = []
|
c = []
|
||||||
cond_or_uncond = []
|
cond_or_uncond = []
|
||||||
|
uuids = []
|
||||||
area = []
|
area = []
|
||||||
control = None
|
control = None
|
||||||
patches = None
|
patches = None
|
||||||
@ -190,6 +270,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|||||||
c.append(p.conditioning)
|
c.append(p.conditioning)
|
||||||
area.append(p.area)
|
area.append(p.area)
|
||||||
cond_or_uncond.append(o[1])
|
cond_or_uncond.append(o[1])
|
||||||
|
uuids.append(p.uuid)
|
||||||
control = p.control
|
control = p.control
|
||||||
patches = p.patches
|
patches = p.patches
|
||||||
|
|
||||||
@ -198,14 +279,14 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|||||||
c = cond_cat(c)
|
c = cond_cat(c)
|
||||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||||
|
|
||||||
if control is not None:
|
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
|
||||||
|
|
||||||
transformer_options = {}
|
|
||||||
if 'transformer_options' in model_options:
|
if 'transformer_options' in model_options:
|
||||||
transformer_options = model_options['transformer_options'].copy()
|
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||||
|
model_options['transformer_options'],
|
||||||
|
copy_dict1=False)
|
||||||
|
|
||||||
if patches is not None:
|
if patches is not None:
|
||||||
|
# TODO: replace with merge_nested_dicts function
|
||||||
if "patches" in transformer_options:
|
if "patches" in transformer_options:
|
||||||
cur_patches = transformer_options["patches"].copy()
|
cur_patches = transformer_options["patches"].copy()
|
||||||
for p in patches:
|
for p in patches:
|
||||||
@ -218,10 +299,14 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|||||||
transformer_options["patches"] = patches
|
transformer_options["patches"] = patches
|
||||||
|
|
||||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||||
|
transformer_options["uuids"] = uuids[:]
|
||||||
transformer_options["sigmas"] = timestep
|
transformer_options["sigmas"] = timestep
|
||||||
|
|
||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
|
if control is not None:
|
||||||
|
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||||
|
|
||||||
if 'model_function_wrapper' in model_options:
|
if 'model_function_wrapper' in model_options:
|
||||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||||
else:
|
else:
|
||||||
@ -500,6 +585,11 @@ def calculate_start_end_timesteps(model, conds):
|
|||||||
|
|
||||||
timestep_start = None
|
timestep_start = None
|
||||||
timestep_end = None
|
timestep_end = None
|
||||||
|
# handle clip hook schedule, if needed
|
||||||
|
if 'clip_start_percent' in x:
|
||||||
|
timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
|
||||||
|
timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
|
||||||
|
else:
|
||||||
if 'start_percent' in x:
|
if 'start_percent' in x:
|
||||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||||
if 'end_percent' in x:
|
if 'end_percent' in x:
|
||||||
@ -673,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
|||||||
if k != kk:
|
if k != kk:
|
||||||
create_cond_with_same_area_if_none(conds[kk], c)
|
create_cond_with_same_area_if_none(conds[kk], c)
|
||||||
|
|
||||||
|
for k in conds:
|
||||||
|
for c in conds[k]:
|
||||||
|
if 'hooks' in c:
|
||||||
|
for hook in c['hooks'].hooks:
|
||||||
|
hook.initialize_timesteps(model)
|
||||||
|
|
||||||
for k in conds:
|
for k in conds:
|
||||||
pre_run_control(model, conds[k])
|
pre_run_control(model, conds[k])
|
||||||
|
|
||||||
@ -685,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
|||||||
|
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
||||||
|
# determine which ControlNets have extra_hooks that should be combined with normal hooks
|
||||||
|
hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
if 'control' in kk:
|
||||||
|
control: 'ControlBase' = kk['control']
|
||||||
|
extra_hooks = control.get_extra_hooks()
|
||||||
|
if len(extra_hooks) > 0:
|
||||||
|
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
|
||||||
|
to_replace = hook_replacement.setdefault((control, hooks), [])
|
||||||
|
to_replace.append(kk)
|
||||||
|
# if nothing to replace, do nothing
|
||||||
|
if len(hook_replacement) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
|
||||||
|
# on the cond dicts
|
||||||
|
for key, conds_to_modify in hook_replacement.items():
|
||||||
|
control = key[0]
|
||||||
|
hooks = key[1]
|
||||||
|
hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
|
||||||
|
# if combined hooks are not None, set as new hooks for all relevant conds
|
||||||
|
if hooks is not None:
|
||||||
|
for cond in conds_to_modify:
|
||||||
|
cond['hooks'] = hooks
|
||||||
|
|
||||||
|
|
||||||
|
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||||
|
hooks_set = set()
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
hooks_set.add(kk.get('hooks', None))
|
||||||
|
return len(hooks_set)
|
||||||
|
|
||||||
|
|
||||||
class CFGGuider:
|
class CFGGuider:
|
||||||
def __init__(self, model_patcher):
|
def __init__(self, model_patcher):
|
||||||
self.model_patcher = model_patcher
|
self.model_patcher: 'ModelPatcher' = model_patcher
|
||||||
self.model_options = model_patcher.model_options
|
self.model_options = model_patcher.model_options
|
||||||
self.original_conds = {}
|
self.original_conds = {}
|
||||||
self.cfg = 1.0
|
self.cfg = 1.0
|
||||||
@ -714,19 +847,17 @@ class CFGGuider:
|
|||||||
|
|
||||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||||
|
|
||||||
extra_args = {"model_options": self.model_options, "seed":seed}
|
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
|
||||||
|
|
||||||
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
sampler.sample,
|
||||||
|
sampler,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
|
||||||
|
)
|
||||||
|
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
if sigmas.shape[-1] == 0:
|
|
||||||
return latent_image
|
|
||||||
|
|
||||||
self.conds = {}
|
|
||||||
for k in self.original_conds:
|
|
||||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
|
||||||
|
|
||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
@ -737,14 +868,48 @@ class CFGGuider:
|
|||||||
latent_image = latent_image.to(device)
|
latent_image = latent_image.to(device)
|
||||||
sigmas = sigmas.to(device)
|
sigmas = sigmas.to(device)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.model_patcher.pre_run()
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
|
finally:
|
||||||
|
self.model_patcher.cleanup()
|
||||||
|
|
||||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||||
del self.inner_model
|
del self.inner_model
|
||||||
del self.conds
|
|
||||||
del self.loaded_models
|
del self.loaded_models
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
|
if sigmas.shape[-1] == 0:
|
||||||
|
return latent_image
|
||||||
|
|
||||||
|
self.conds = {}
|
||||||
|
for k in self.original_conds:
|
||||||
|
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||||
|
preprocess_conds_hooks(self.conds)
|
||||||
|
|
||||||
|
try:
|
||||||
|
orig_model_options = self.model_options
|
||||||
|
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||||
|
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
|
||||||
|
orig_hook_mode = self.model_patcher.hook_mode
|
||||||
|
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||||
|
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
|
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||||
|
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self.outer_sample,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||||
|
)
|
||||||
|
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
|
finally:
|
||||||
|
self.model_options = orig_model_options
|
||||||
|
self.model_patcher.hook_mode = orig_hook_mode
|
||||||
|
self.model_patcher.restore_hook_patches()
|
||||||
|
|
||||||
|
del self.conds
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
cfg_guider = CFGGuider(model)
|
cfg_guider = CFGGuider(model)
|
||||||
|
73
comfy/sd.py
73
comfy/sd.py
@ -1,8 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||||
from .ldm.cascade.stage_a import StageA
|
from .ldm.cascade.stage_a import StageA
|
||||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||||
@ -33,6 +35,7 @@ import comfy.text_encoders.lt
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
import comfy.lora_convert
|
import comfy.lora_convert
|
||||||
|
import comfy.hooks
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
import comfy.taesd.taesd
|
import comfy.taesd.taesd
|
||||||
|
|
||||||
@ -98,9 +101,13 @@ class CLIP:
|
|||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
|
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
|
self.patcher.is_clip = True
|
||||||
|
self.apply_hooks_to_conds = None
|
||||||
if params['device'] == load_device:
|
if params['device'] == load_device:
|
||||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
|
self.use_clip_schedule = False
|
||||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
@ -109,6 +116,8 @@ class CLIP:
|
|||||||
n.cond_stage_model = self.cond_stage_model
|
n.cond_stage_model = self.cond_stage_model
|
||||||
n.tokenizer = self.tokenizer
|
n.tokenizer = self.tokenizer
|
||||||
n.layer_idx = self.layer_idx
|
n.layer_idx = self.layer_idx
|
||||||
|
n.use_clip_schedule = self.use_clip_schedule
|
||||||
|
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
@ -120,6 +129,69 @@ class CLIP:
|
|||||||
def tokenize(self, text, return_word_ids=False):
|
def tokenize(self, text, return_word_ids=False):
|
||||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||||
|
|
||||||
|
def add_hooks_to_dict(self, pooled_dict: dict[str]):
|
||||||
|
if self.apply_hooks_to_conds:
|
||||||
|
pooled_dict["hooks"] = self.apply_hooks_to_conds
|
||||||
|
return pooled_dict
|
||||||
|
|
||||||
|
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
|
||||||
|
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
|
||||||
|
all_hooks = self.patcher.forced_hooks
|
||||||
|
if all_hooks is None or not self.use_clip_schedule:
|
||||||
|
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
|
||||||
|
return_pooled = "unprojected" if unprojected else True
|
||||||
|
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
|
||||||
|
cond = pooled_dict.pop("cond")
|
||||||
|
# add/update any keys with the provided add_dict
|
||||||
|
pooled_dict.update(add_dict)
|
||||||
|
all_cond_pooled.append([cond, pooled_dict])
|
||||||
|
else:
|
||||||
|
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
|
||||||
|
|
||||||
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
if self.layer_idx is not None:
|
||||||
|
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||||
|
if unprojected:
|
||||||
|
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||||
|
|
||||||
|
self.load_model()
|
||||||
|
all_hooks.reset()
|
||||||
|
self.patcher.patch_hooks(None)
|
||||||
|
if show_pbar:
|
||||||
|
pbar = ProgressBar(len(scheduled_keyframes))
|
||||||
|
|
||||||
|
for scheduled_opts in scheduled_keyframes:
|
||||||
|
t_range = scheduled_opts[0]
|
||||||
|
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||||
|
if "start_percent" in add_dict:
|
||||||
|
if t_range[1] < add_dict["start_percent"]:
|
||||||
|
continue
|
||||||
|
if "end_percent" in add_dict:
|
||||||
|
if t_range[0] > add_dict["end_percent"]:
|
||||||
|
continue
|
||||||
|
hooks_keyframes = scheduled_opts[1]
|
||||||
|
for hook, keyframe in hooks_keyframes:
|
||||||
|
hook.hook_keyframe._current_keyframe = keyframe
|
||||||
|
# apply appropriate hooks with values that match new hook_keyframe
|
||||||
|
self.patcher.patch_hooks(all_hooks)
|
||||||
|
# perform encoding as normal
|
||||||
|
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
|
cond, pooled = o[:2]
|
||||||
|
pooled_dict = {"pooled_output": pooled}
|
||||||
|
# add clip_start_percent and clip_end_percent in pooled
|
||||||
|
pooled_dict["clip_start_percent"] = t_range[0]
|
||||||
|
pooled_dict["clip_end_percent"] = t_range[1]
|
||||||
|
# add/update any keys with the provided add_dict
|
||||||
|
pooled_dict.update(add_dict)
|
||||||
|
# add hooks stored on clip
|
||||||
|
self.add_hooks_to_dict(pooled_dict)
|
||||||
|
all_cond_pooled.append([cond, pooled_dict])
|
||||||
|
if show_pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
|
all_hooks.reset()
|
||||||
|
return all_cond_pooled
|
||||||
|
|
||||||
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
|
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
|
||||||
self.cond_stage_model.reset_clip_options()
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
|
||||||
@ -137,6 +209,7 @@ class CLIP:
|
|||||||
if len(o) > 2:
|
if len(o) > 2:
|
||||||
for k in o[2]:
|
for k in o[2]:
|
||||||
out[k] = o[2][k]
|
out[k] = o[2][k]
|
||||||
|
self.add_hooks_to_dict(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
if return_pooled:
|
if return_pooled:
|
||||||
|
@ -17,8 +17,7 @@ class CLIPTextEncodeSDXLRefiner:
|
|||||||
|
|
||||||
def encode(self, clip, ascore, width, height, text):
|
def encode(self, clip, ascore, width, height, text):
|
||||||
tokens = clip.tokenize(text)
|
tokens = clip.tokenize(text)
|
||||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
|
||||||
return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
|
|
||||||
|
|
||||||
class CLIPTextEncodeSDXL:
|
class CLIPTextEncodeSDXL:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -47,8 +46,7 @@ class CLIPTextEncodeSDXL:
|
|||||||
tokens["l"] += empty["l"]
|
tokens["l"] += empty["l"]
|
||||||
while len(tokens["l"]) > len(tokens["g"]):
|
while len(tokens["l"]) > len(tokens["g"]):
|
||||||
tokens["g"] += empty["g"]
|
tokens["g"] += empty["g"]
|
||||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
|
||||||
return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
|
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
|
||||||
|
@ -18,10 +18,7 @@ class CLIPTextEncodeFlux:
|
|||||||
tokens = clip.tokenize(clip_l)
|
tokens = clip.tokenize(clip_l)
|
||||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||||
|
|
||||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
|
||||||
cond = output.pop("cond")
|
|
||||||
output["guidance"] = guidance
|
|
||||||
return ([[cond, output]], )
|
|
||||||
|
|
||||||
class FluxGuidance:
|
class FluxGuidance:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
697
comfy_extras/nodes_hooks.py
Normal file
697
comfy_extras/nodes_hooks.py
Normal file
@ -0,0 +1,697 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import TYPE_CHECKING, Union
|
||||||
|
import torch
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
from comfy.sd import CLIP
|
||||||
|
|
||||||
|
import comfy.hooks
|
||||||
|
import comfy.sd
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Mask, Combine, and Hook Conditioning
|
||||||
|
#------------------------------------------
|
||||||
|
class PairConditioningSetProperties:
|
||||||
|
NodeId = 'PairConditioningSetProperties'
|
||||||
|
NodeName = 'Cond Pair Set Props'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"positive_NEW": ("CONDITIONING", ),
|
||||||
|
"negative_NEW": ("CONDITIONING", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"set_cond_area": (["default", "mask bounds"],),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
"timesteps": ("TIMESTEPS_RANGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
|
RETURN_NAMES = ("positive", "negative")
|
||||||
|
CATEGORY = "advanced/hooks/cond pair"
|
||||||
|
FUNCTION = "set_properties"
|
||||||
|
|
||||||
|
def set_properties(self, positive_NEW, negative_NEW,
|
||||||
|
strength: float, set_cond_area: str,
|
||||||
|
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||||
|
final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
|
||||||
|
strength=strength, set_cond_area=set_cond_area,
|
||||||
|
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||||
|
return (final_positive, final_negative)
|
||||||
|
|
||||||
|
class PairConditioningSetPropertiesAndCombine:
|
||||||
|
NodeId = 'PairConditioningSetPropertiesAndCombine'
|
||||||
|
NodeName = 'Cond Pair Set Props Combine'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"positive_NEW": ("CONDITIONING", ),
|
||||||
|
"negative_NEW": ("CONDITIONING", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"set_cond_area": (["default", "mask bounds"],),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
"timesteps": ("TIMESTEPS_RANGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
|
RETURN_NAMES = ("positive", "negative")
|
||||||
|
CATEGORY = "advanced/hooks/cond pair"
|
||||||
|
FUNCTION = "set_properties"
|
||||||
|
|
||||||
|
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
|
||||||
|
strength: float, set_cond_area: str,
|
||||||
|
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||||
|
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
|
||||||
|
strength=strength, set_cond_area=set_cond_area,
|
||||||
|
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||||
|
return (final_positive, final_negative)
|
||||||
|
|
||||||
|
class ConditioningSetProperties:
|
||||||
|
NodeId = 'ConditioningSetProperties'
|
||||||
|
NodeName = 'Cond Set Props'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"cond_NEW": ("CONDITIONING", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"set_cond_area": (["default", "mask bounds"],),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
"timesteps": ("TIMESTEPS_RANGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
CATEGORY = "advanced/hooks/cond single"
|
||||||
|
FUNCTION = "set_properties"
|
||||||
|
|
||||||
|
def set_properties(self, cond_NEW,
|
||||||
|
strength: float, set_cond_area: str,
|
||||||
|
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||||
|
(final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
|
||||||
|
strength=strength, set_cond_area=set_cond_area,
|
||||||
|
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||||
|
return (final_cond,)
|
||||||
|
|
||||||
|
class ConditioningSetPropertiesAndCombine:
|
||||||
|
NodeId = 'ConditioningSetPropertiesAndCombine'
|
||||||
|
NodeName = 'Cond Set Props Combine'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"cond": ("CONDITIONING", ),
|
||||||
|
"cond_NEW": ("CONDITIONING", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"set_cond_area": (["default", "mask bounds"],),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
"timesteps": ("TIMESTEPS_RANGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
CATEGORY = "advanced/hooks/cond single"
|
||||||
|
FUNCTION = "set_properties"
|
||||||
|
|
||||||
|
def set_properties(self, cond, cond_NEW,
|
||||||
|
strength: float, set_cond_area: str,
|
||||||
|
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
||||||
|
(final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
|
||||||
|
strength=strength, set_cond_area=set_cond_area,
|
||||||
|
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||||
|
return (final_cond,)
|
||||||
|
|
||||||
|
class PairConditioningCombine:
|
||||||
|
NodeId = 'PairConditioningCombine'
|
||||||
|
NodeName = 'Cond Pair Combine'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"positive_A": ("CONDITIONING",),
|
||||||
|
"negative_A": ("CONDITIONING",),
|
||||||
|
"positive_B": ("CONDITIONING",),
|
||||||
|
"negative_B": ("CONDITIONING",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
|
RETURN_NAMES = ("positive", "negative")
|
||||||
|
CATEGORY = "advanced/hooks/cond pair"
|
||||||
|
FUNCTION = "combine"
|
||||||
|
|
||||||
|
def combine(self, positive_A, negative_A, positive_B, negative_B):
|
||||||
|
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
|
||||||
|
return (final_positive, final_negative,)
|
||||||
|
|
||||||
|
class PairConditioningSetDefaultAndCombine:
|
||||||
|
NodeId = 'PairConditioningSetDefaultCombine'
|
||||||
|
NodeName = 'Cond Pair Set Default Combine'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"positive": ("CONDITIONING",),
|
||||||
|
"negative": ("CONDITIONING",),
|
||||||
|
"positive_DEFAULT": ("CONDITIONING",),
|
||||||
|
"negative_DEFAULT": ("CONDITIONING",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
|
RETURN_NAMES = ("positive", "negative")
|
||||||
|
CATEGORY = "advanced/hooks/cond pair"
|
||||||
|
FUNCTION = "set_default_and_combine"
|
||||||
|
|
||||||
|
def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
|
||||||
|
hooks: comfy.hooks.HookGroup=None):
|
||||||
|
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
|
||||||
|
hooks=hooks)
|
||||||
|
return (final_positive, final_negative)
|
||||||
|
|
||||||
|
class ConditioningSetDefaultAndCombine:
|
||||||
|
NodeId = 'ConditioningSetDefaultCombine'
|
||||||
|
NodeName = 'Cond Set Default Combine'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"cond": ("CONDITIONING",),
|
||||||
|
"cond_DEFAULT": ("CONDITIONING",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
CATEGORY = "advanced/hooks/cond single"
|
||||||
|
FUNCTION = "set_default_and_combine"
|
||||||
|
|
||||||
|
def set_default_and_combine(self, cond, cond_DEFAULT,
|
||||||
|
hooks: comfy.hooks.HookGroup=None):
|
||||||
|
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
|
||||||
|
hooks=hooks)
|
||||||
|
return (final_conditioning,)
|
||||||
|
|
||||||
|
class SetClipHooks:
|
||||||
|
NodeId = 'SetClipHooks'
|
||||||
|
NodeName = 'Set CLIP Hooks'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"clip": ("CLIP",),
|
||||||
|
"apply_to_conds": ("BOOLEAN", {"default": True}),
|
||||||
|
"schedule_clip": ("BOOLEAN", {"default": False})
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"hooks": ("HOOKS",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CLIP",)
|
||||||
|
CATEGORY = "advanced/hooks/clip"
|
||||||
|
FUNCTION = "apply_hooks"
|
||||||
|
|
||||||
|
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||||
|
if hooks is not None:
|
||||||
|
clip = clip.clone()
|
||||||
|
if apply_to_conds:
|
||||||
|
clip.apply_hooks_to_conds = hooks
|
||||||
|
clip.patcher.forced_hooks = hooks.clone()
|
||||||
|
clip.use_clip_schedule = schedule_clip
|
||||||
|
if not clip.use_clip_schedule:
|
||||||
|
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
||||||
|
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
|
||||||
|
return (clip,)
|
||||||
|
|
||||||
|
class ConditioningTimestepsRange:
|
||||||
|
NodeId = 'ConditioningTimestepsRange'
|
||||||
|
NodeName = 'Timesteps Range'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
|
||||||
|
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
|
||||||
|
CATEGORY = "advanced/hooks"
|
||||||
|
FUNCTION = "create_range"
|
||||||
|
|
||||||
|
def create_range(self, start_percent: float, end_percent: float):
|
||||||
|
return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
|
||||||
|
#------------------------------------------
|
||||||
|
###########################################
|
||||||
|
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Create Hooks
|
||||||
|
#------------------------------------------
|
||||||
|
class CreateHookLora:
|
||||||
|
NodeId = 'CreateHookLora'
|
||||||
|
NodeName = 'Create Hook LoRA'
|
||||||
|
def __init__(self):
|
||||||
|
self.loaded_lora = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||||
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
|
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"prev_hooks": ("HOOKS",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/create"
|
||||||
|
FUNCTION = "create_hook"
|
||||||
|
|
||||||
|
def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
|
||||||
|
if prev_hooks is None:
|
||||||
|
prev_hooks = comfy.hooks.HookGroup()
|
||||||
|
prev_hooks.clone()
|
||||||
|
|
||||||
|
if strength_model == 0 and strength_clip == 0:
|
||||||
|
return (prev_hooks,)
|
||||||
|
|
||||||
|
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||||
|
lora = None
|
||||||
|
if self.loaded_lora is not None:
|
||||||
|
if self.loaded_lora[0] == lora_path:
|
||||||
|
lora = self.loaded_lora[1]
|
||||||
|
else:
|
||||||
|
temp = self.loaded_lora
|
||||||
|
self.loaded_lora = None
|
||||||
|
del temp
|
||||||
|
|
||||||
|
if lora is None:
|
||||||
|
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||||
|
self.loaded_lora = (lora_path, lora)
|
||||||
|
|
||||||
|
hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
|
||||||
|
return (prev_hooks.clone_and_combine(hooks),)
|
||||||
|
|
||||||
|
class CreateHookLoraModelOnly(CreateHookLora):
|
||||||
|
NodeId = 'CreateHookLoraModelOnly'
|
||||||
|
NodeName = 'Create Hook LoRA (MO)'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||||
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"prev_hooks": ("HOOKS",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/create"
|
||||||
|
FUNCTION = "create_hook_model_only"
|
||||||
|
|
||||||
|
def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
|
||||||
|
return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
|
||||||
|
|
||||||
|
class CreateHookModelAsLora:
|
||||||
|
NodeId = 'CreateHookModelAsLora'
|
||||||
|
NodeName = 'Create Hook Model as LoRA'
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# when not None, will be in following format:
|
||||||
|
# (ckpt_path: str, weights_model: dict, weights_clip: dict)
|
||||||
|
self.loaded_weights = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||||
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
|
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"prev_hooks": ("HOOKS",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/create"
|
||||||
|
FUNCTION = "create_hook"
|
||||||
|
|
||||||
|
def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
|
||||||
|
prev_hooks: comfy.hooks.HookGroup=None):
|
||||||
|
if prev_hooks is None:
|
||||||
|
prev_hooks = comfy.hooks.HookGroup()
|
||||||
|
prev_hooks.clone()
|
||||||
|
|
||||||
|
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||||
|
weights_model = None
|
||||||
|
weights_clip = None
|
||||||
|
if self.loaded_weights is not None:
|
||||||
|
if self.loaded_weights[0] == ckpt_path:
|
||||||
|
weights_model = self.loaded_weights[1]
|
||||||
|
weights_clip = self.loaded_weights[2]
|
||||||
|
else:
|
||||||
|
temp = self.loaded_weights
|
||||||
|
self.loaded_weights = None
|
||||||
|
del temp
|
||||||
|
|
||||||
|
if weights_model is None:
|
||||||
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
|
||||||
|
weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
|
||||||
|
self.loaded_weights = (ckpt_path, weights_model, weights_clip)
|
||||||
|
|
||||||
|
hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
|
||||||
|
strength_model=strength_model, strength_clip=strength_clip)
|
||||||
|
return (prev_hooks.clone_and_combine(hooks),)
|
||||||
|
|
||||||
|
class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
|
||||||
|
NodeId = 'CreateHookModelAsLoraModelOnly'
|
||||||
|
NodeName = 'Create Hook Model as LoRA (MO)'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||||
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"prev_hooks": ("HOOKS",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/create"
|
||||||
|
FUNCTION = "create_hook_model_only"
|
||||||
|
|
||||||
|
def create_hook_model_only(self, ckpt_name: str, strength_model: float,
|
||||||
|
prev_hooks: comfy.hooks.HookGroup=None):
|
||||||
|
return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
|
||||||
|
#------------------------------------------
|
||||||
|
###########################################
|
||||||
|
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Schedule Hooks
|
||||||
|
#------------------------------------------
|
||||||
|
class SetHookKeyframes:
|
||||||
|
NodeId = 'SetHookKeyframes'
|
||||||
|
NodeName = 'Set Hook Keyframes'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"hook_kf": ("HOOK_KEYFRAMES",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/scheduling"
|
||||||
|
FUNCTION = "set_hook_keyframes"
|
||||||
|
|
||||||
|
def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
||||||
|
if hook_kf is not None:
|
||||||
|
hooks = hooks.clone()
|
||||||
|
hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
|
||||||
|
return (hooks,)
|
||||||
|
|
||||||
|
class CreateHookKeyframe:
|
||||||
|
NodeId = 'CreateHookKeyframe'
|
||||||
|
NodeName = 'Create Hook Keyframe'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||||
|
RETURN_NAMES = ("HOOK_KF",)
|
||||||
|
CATEGORY = "advanced/hooks/scheduling"
|
||||||
|
FUNCTION = "create_hook_keyframe"
|
||||||
|
|
||||||
|
def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
||||||
|
if prev_hook_kf is None:
|
||||||
|
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
||||||
|
prev_hook_kf = prev_hook_kf.clone()
|
||||||
|
keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
|
||||||
|
prev_hook_kf.add(keyframe)
|
||||||
|
return (prev_hook_kf,)
|
||||||
|
|
||||||
|
class CreateHookKeyframesFromFloats:
|
||||||
|
NodeId = 'CreateHookKeyframesFromFloats'
|
||||||
|
NodeName = 'Create Hook Keyframes From Floats'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"print_keyframes": ("BOOLEAN", {"default": False}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||||
|
RETURN_NAMES = ("HOOK_KF",)
|
||||||
|
CATEGORY = "advanced/hooks/scheduling"
|
||||||
|
FUNCTION = "create_hook_keyframes"
|
||||||
|
|
||||||
|
def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
|
||||||
|
start_percent: float, end_percent: float,
|
||||||
|
prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
|
||||||
|
if prev_hook_kf is None:
|
||||||
|
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
||||||
|
prev_hook_kf = prev_hook_kf.clone()
|
||||||
|
if type(floats_strength) in (float, int):
|
||||||
|
floats_strength = [float(floats_strength)]
|
||||||
|
elif isinstance(floats_strength, Iterable):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
|
||||||
|
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
|
||||||
|
method=comfy.hooks.InterpolationMethod.LINEAR)
|
||||||
|
|
||||||
|
is_first = True
|
||||||
|
for percent, strength in zip(percents, floats_strength):
|
||||||
|
guarantee_steps = 0
|
||||||
|
if is_first:
|
||||||
|
guarantee_steps = 1
|
||||||
|
is_first = False
|
||||||
|
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
|
||||||
|
if print_keyframes:
|
||||||
|
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
|
||||||
|
return (prev_hook_kf,)
|
||||||
|
#------------------------------------------
|
||||||
|
###########################################
|
||||||
|
|
||||||
|
|
||||||
|
class SetModelHooksOnCond:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"conditioning": ("CONDITIONING",),
|
||||||
|
"hooks": ("HOOKS",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
CATEGORY = "advanced/hooks/manual"
|
||||||
|
FUNCTION = "attach_hook"
|
||||||
|
|
||||||
|
def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
|
||||||
|
return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
|
||||||
|
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Combine Hooks
|
||||||
|
#------------------------------------------
|
||||||
|
class CombineHooks:
|
||||||
|
NodeId = 'CombineHooks2'
|
||||||
|
NodeName = 'Combine Hooks [2]'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"hooks_A": ("HOOKS",),
|
||||||
|
"hooks_B": ("HOOKS",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/combine"
|
||||||
|
FUNCTION = "combine_hooks"
|
||||||
|
|
||||||
|
def combine_hooks(self,
|
||||||
|
hooks_A: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_B: comfy.hooks.HookGroup=None):
|
||||||
|
candidates = [hooks_A, hooks_B]
|
||||||
|
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
||||||
|
|
||||||
|
class CombineHooksFour:
|
||||||
|
NodeId = 'CombineHooks4'
|
||||||
|
NodeName = 'Combine Hooks [4]'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"hooks_A": ("HOOKS",),
|
||||||
|
"hooks_B": ("HOOKS",),
|
||||||
|
"hooks_C": ("HOOKS",),
|
||||||
|
"hooks_D": ("HOOKS",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/combine"
|
||||||
|
FUNCTION = "combine_hooks"
|
||||||
|
|
||||||
|
def combine_hooks(self,
|
||||||
|
hooks_A: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_B: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_C: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_D: comfy.hooks.HookGroup=None):
|
||||||
|
candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
|
||||||
|
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
||||||
|
|
||||||
|
class CombineHooksEight:
|
||||||
|
NodeId = 'CombineHooks8'
|
||||||
|
NodeName = 'Combine Hooks [8]'
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"hooks_A": ("HOOKS",),
|
||||||
|
"hooks_B": ("HOOKS",),
|
||||||
|
"hooks_C": ("HOOKS",),
|
||||||
|
"hooks_D": ("HOOKS",),
|
||||||
|
"hooks_E": ("HOOKS",),
|
||||||
|
"hooks_F": ("HOOKS",),
|
||||||
|
"hooks_G": ("HOOKS",),
|
||||||
|
"hooks_H": ("HOOKS",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
RETURN_TYPES = ("HOOKS",)
|
||||||
|
CATEGORY = "advanced/hooks/combine"
|
||||||
|
FUNCTION = "combine_hooks"
|
||||||
|
|
||||||
|
def combine_hooks(self,
|
||||||
|
hooks_A: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_B: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_C: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_D: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_E: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_F: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_G: comfy.hooks.HookGroup=None,
|
||||||
|
hooks_H: comfy.hooks.HookGroup=None):
|
||||||
|
candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
|
||||||
|
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
||||||
|
#------------------------------------------
|
||||||
|
###########################################
|
||||||
|
|
||||||
|
node_list = [
|
||||||
|
# Create
|
||||||
|
CreateHookLora,
|
||||||
|
CreateHookLoraModelOnly,
|
||||||
|
CreateHookModelAsLora,
|
||||||
|
CreateHookModelAsLoraModelOnly,
|
||||||
|
# Scheduling
|
||||||
|
SetHookKeyframes,
|
||||||
|
CreateHookKeyframe,
|
||||||
|
CreateHookKeyframesFromFloats,
|
||||||
|
# Combine
|
||||||
|
CombineHooks,
|
||||||
|
CombineHooksFour,
|
||||||
|
CombineHooksEight,
|
||||||
|
# Attach
|
||||||
|
ConditioningSetProperties,
|
||||||
|
ConditioningSetPropertiesAndCombine,
|
||||||
|
PairConditioningSetProperties,
|
||||||
|
PairConditioningSetPropertiesAndCombine,
|
||||||
|
ConditioningSetDefaultAndCombine,
|
||||||
|
PairConditioningSetDefaultAndCombine,
|
||||||
|
PairConditioningCombine,
|
||||||
|
SetClipHooks,
|
||||||
|
# Other
|
||||||
|
ConditioningTimestepsRange,
|
||||||
|
]
|
||||||
|
NODE_CLASS_MAPPINGS = {}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
|
||||||
|
for node in node_list:
|
||||||
|
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
@ -15,9 +15,7 @@ class CLIPTextEncodeHunyuanDiT:
|
|||||||
tokens = clip.tokenize(bert)
|
tokens = clip.tokenize(bert)
|
||||||
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
|
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
|
||||||
|
|
||||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||||
cond = output.pop("cond")
|
|
||||||
return ([[cond, output]], )
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
@ -82,8 +82,7 @@ class CLIPTextEncodeSD3:
|
|||||||
tokens["l"] += empty["l"]
|
tokens["l"] += empty["l"]
|
||||||
while len(tokens["l"]) > len(tokens["g"]):
|
while len(tokens["l"]) > len(tokens["g"]):
|
||||||
tokens["g"] += empty["g"]
|
tokens["g"] += empty["g"]
|
||||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||||
return ([[cond, {"pooled_output": pooled}]], )
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||||
|
6
nodes.py
6
nodes.py
@ -62,9 +62,8 @@ class CLIPTextEncode:
|
|||||||
|
|
||||||
def encode(self, clip, text):
|
def encode(self, clip, text):
|
||||||
tokens = clip.tokenize(text)
|
tokens = clip.tokenize(text)
|
||||||
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
|
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||||
cond = output.pop("cond")
|
|
||||||
return ([[cond, output]], )
|
|
||||||
|
|
||||||
class ConditioningCombine:
|
class ConditioningCombine:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -2149,6 +2148,7 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_mochi.py",
|
"nodes_mochi.py",
|
||||||
"nodes_slg.py",
|
"nodes_slg.py",
|
||||||
"nodes_lt.py",
|
"nodes_lt.py",
|
||||||
|
"nodes_hooks.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
Loading…
Reference in New Issue
Block a user