from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools
import logging

if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher, PatcherInjection
    from comfy.model_base import BaseModel
    from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values

# #######################################################################################################
# Hooks explanation
# -------------------
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
# make explicit special cases like it does for ControlNet and GLIGEN.
#
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
# that should run special code when a 'marked' cond is used in sampling.
# #######################################################################################################

class EnumHookMode(enum.Enum):
    '''
    Priority of hook memory optimization vs. speed, mostly related to WeightHooks.

    MinVram: No caching will occur for any operations related to hooks.
    MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
    '''
    # Favor low memory use: nothing hook-related is cached.
    MinVram = "minvram"
    # Favor speed: spare memory is spent on caching hook weights between hook group switches.
    MaxSpeed = "maxspeed"

class EnumHookType(enum.Enum):
    '''
    Hook types, each of which has different expected behavior.

    Each Hook subclass in this module passes exactly one of these values to
    Hook.__init__, and HookGroup buckets its hooks by this value.
    '''
    # used by WeightHook
    Weight = "weight"
    # used by ObjectPatchHook
    ObjectPatch = "object_patch"
    # used by AdditionalModelsHook
    AdditionalModels = "add_models"
    # used by TransformerOptionsHook (a.k.a. WrapperHook)
    TransformerOptions = "transformer_options"
    # used by InjectionsHook
    Injections = "add_injections"

class EnumWeightTarget(enum.Enum):
    '''
    Identifies what a WeightHook is being registered against.

    Stored under the 'target' key of a target dict (see create_target_dict);
    WeightHook.add_hook_patches treats Clip specially and everything else
    (including a missing target) as the model case.
    '''
    Model = "model"
    Clip = "clip"

class EnumHookScope(enum.Enum):
    '''
    Determines if hook should be limited in its influence over sampling.

    AllConditioning: hook will affect all conds used in sampling.
    HookedOnly: hook will only affect the conds it was attached to.
    '''
    # default scope for Hook and most subclasses in this module
    AllConditioning = "all_conditioning"
    # scope WeightHook sets on itself
    HookedOnly = "hooked_only"


class _HookRef:
    '''
    Empty marker object used as a hook's identity.

    Hook.clone() passes the same _HookRef instance to the clone, and
    Hook.__eq__/__hash__ compare and hash hooks via this reference, so a hook
    and its clones are considered equal. The class defines no __eq__ or
    __hash__ of its own, so the default object identity semantics apply.
    '''
    pass


def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
    '''
    Example for how custom_should_register function can look like.

    This is the default value of Hook.custom_should_register; Hook.should_register
    calls it with the hook itself followed by the arguments it received.

    Args:
        hook: hook that is asking whether it should be registered.
        model: patcher the hook would be registered on.
        model_options: options dict passed along by the caller.
        target_dict: target description (see create_target_dict).
        registered: group of hooks that have been registered so far.

    Returns:
        Always True, i.e. the hook is always registered.
    '''
    return True


def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
    '''
    Build the dictionary passed to hooks as their target param.

    Args:
        target: optional target value; stored under the 'target' key when not None.
        **kwargs: any extra entries to place in the dictionary.

    Returns:
        A new dict holding 'target' (if given) followed by the extra entries.
    '''
    base = {} if target is None else {'target': target}
    return {**base, **kwargs}


class Hook:
    '''
    Base class for hooks.

    Holds the bookkeeping shared by every hook (type, identity reference, optional id,
    keyframes, scope) and defines the registration interface that subclasses implement
    through add_hook_patches.
    '''
    def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
                 hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
        # Enum identifying the general class of this hook.
        self.hook_type = hook_type
        # Reference shared between hook clones that have the same value. Should NOT be modified.
        self.hook_ref = hook_ref or _HookRef()
        # Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.
        self.hook_id = hook_id
        # Keyframe storage that can be referenced to get strength for current sampling step.
        self.hook_keyframe = hook_keyframe or HookKeyframeGroup()
        # Scope of where this hook should apply in terms of the conds used in sampling run.
        self.hook_scope = hook_scope
        # Can be overridden with a compatible function to decide if this hook should be
        # registered without the need to override .should_register
        self.custom_should_register = default_should_register

    @property
    def strength(self):
        '''Strength reported by this hook's keyframe group.'''
        return self.hook_keyframe.strength

    def initialize_timesteps(self, model: BaseModel):
        '''Reset keyframe state, then let the keyframe group initialize its timesteps from model.'''
        self.reset()
        self.hook_keyframe.initialize_timesteps(model)

    def reset(self):
        '''Reset the keyframe group of this hook.'''
        self.hook_keyframe.reset()

    def clone(self):
        '''
        Create a new instance of the same class sharing this hook's base attributes.

        The clone is built with a no-argument constructor call, so subclasses must be
        constructible without arguments. hook_ref and hook_keyframe are shared, not copied.
        '''
        duplicate: Hook = self.__class__()
        shared_attrs = ("hook_type", "hook_ref", "hook_id", "hook_keyframe",
                        "hook_scope", "custom_should_register")
        for attr_name in shared_attrs:
            setattr(duplicate, attr_name, getattr(self, attr_name))
        return duplicate

    def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
        '''Delegate the registration decision to custom_should_register.'''
        decide = self.custom_should_register
        return decide(self, model, model_options, target_dict, registered)

    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
        '''Register this hook; must be implemented by subclasses.'''
        raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")

    def __eq__(self, other: Hook):
        # hooks are equal when they are of the same class and share the same hook_ref
        if not (self.__class__ == other.__class__):
            return False
        return self.hook_ref == other.hook_ref

    def __hash__(self):
        # hash must agree with __eq__, so it is derived from hook_ref
        return hash(self.hook_ref)

class WeightHook(Hook):
    '''
    Hook responsible for tracking weights to be applied to some model/clip.

    Note, value of hook_scope is ignored and is treated as HookedOnly.
    '''
    def __init__(self, strength_model=1.0, strength_clip=1.0):
        super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
        # patches (or raw lora weights while need_weight_init is True) for the model
        self.weights: dict = None
        # patches used when the target is clip and need_weight_init is False
        self.weights_clip: dict = None
        # when True, self.weights is run through comfy.lora.load_lora at registration time
        self.need_weight_init = True
        self._strength_model = strength_model
        self._strength_clip = strength_clip
        self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs

    @property
    def strength_model(self):
        '''Base model strength scaled by the current keyframe strength.'''
        return self._strength_model * self.strength

    @property
    def strength_clip(self):
        '''Base clip strength scaled by the current keyframe strength.'''
        return self._strength_clip * self.strength

    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
        '''
        Register this hook's weight patches on model.

        Returns False without side effects when should_register declines; otherwise
        adds the patches to the model, records the hook in registered, and returns True.
        '''
        if not self.should_register(model, model_options, target_dict, registered):
            return False

        # anything that is not explicitly Clip (including a missing target) is treated as model
        targets_clip = target_dict.get('target', None) == EnumWeightTarget.Clip
        patch_strength = self._strength_clip if targets_clip else self._strength_model

        if self.need_weight_init:
            # weights still need to be mapped onto this model's keys
            build_key_map = comfy.lora.model_lora_keys_clip if targets_clip else comfy.lora.model_lora_keys_unet
            key_map = build_key_map(model.model, {})
            patches = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
        else:
            patches = self.weights_clip if targets_clip else self.weights

        model.add_hook_patches(hook=self, patches=patches, strength_patch=patch_strength)
        registered.add(self)
        return True
        # TODO: add logs about any keys that were not applied

    def clone(self):
        '''Clone base attributes, then share the weight dicts and copy the strengths.'''
        duplicate: WeightHook = super().clone()
        for attr_name in ("weights", "weights_clip", "need_weight_init", "_strength_model", "_strength_clip"):
            setattr(duplicate, attr_name, getattr(self, attr_name))
        return duplicate

class ObjectPatchHook(Hook):
    '''
    Hook carrying a dict of object patches.

    Registration is not implemented: add_hook_patches always raises NotImplementedError.
    '''
    def __init__(self, object_patches: dict[str]=None,
                 hook_scope=EnumHookScope.AllConditioning):
        super().__init__(hook_type=EnumHookType.ObjectPatch, hook_scope=hook_scope)
        self.object_patches = object_patches

    def clone(self):
        '''Clone base attributes and share the object_patches dict.'''
        duplicate: ObjectPatchHook = super().clone()
        duplicate.object_patches = self.object_patches
        return duplicate

    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
        '''Not supported; always raises NotImplementedError.'''
        raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")

class AdditionalModelsHook(Hook):
    '''
    Hook responsible for telling model management any additional models that should be loaded.

    Note, value of hook_scope is ignored and is treated as AllConditioning.
    '''
    def __init__(self, models: list[ModelPatcher]=None, key: str=None):
        super().__init__(hook_type=EnumHookType.AdditionalModels)
        self.key = key
        self.models = models

    def clone(self):
        '''Clone base attributes; a non-empty models list is shallow-copied, otherwise passed through as-is.'''
        duplicate: AdditionalModelsHook = super().clone()
        if self.models:
            duplicate.models = self.models.copy()
        else:
            # None or empty list: keep the very same value
            duplicate.models = self.models
        duplicate.key = self.key
        return duplicate

    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
        '''Record this hook in registered when should_register agrees; returns whether it was registered.'''
        accepted = self.should_register(model, model_options, target_dict, registered)
        if not accepted:
            return False
        registered.add(self)
        return True

class TransformerOptionsHook(Hook):
    '''
    Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
    '''
    def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
                 hook_scope=EnumHookScope.AllConditioning):
        super().__init__(hook_type=EnumHookType.TransformerOptions, hook_scope=hook_scope)
        self.transformers_dict = transformers_dict
        # Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.
        self._skip_adding = False

    def clone(self):
        '''Clone base attributes, share transformers_dict, and carry over the _skip_adding flag.'''
        duplicate: TransformerOptionsHook = super().clone()
        duplicate.transformers_dict = self.transformers_dict
        duplicate._skip_adding = self._skip_adding
        return duplicate

    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
        '''
        Merge this hook's transformers_dict into model_options and record the hook.

        Returns False without side effects when should_register declines, True otherwise.
        '''
        if not self.should_register(model, model_options, target_dict, registered):
            return False
        # skip_adding if included in AllConditioning to avoid double loading
        self._skip_adding = self.hook_scope == EnumHookScope.AllConditioning
        add_model_options = {}
        if self._skip_adding:
            add_model_options["transformer_options"] = self.transformers_dict
        # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
        add_model_options["to_load_options"] = self.transformers_dict
        registered.add(self)
        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
        return True

    def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
        '''Merge transformers_dict into transformer_options unless it was already added at registration.'''
        if self._skip_adding:
            return
        comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)

# Plain alias: both names are bound to the very same class object.
WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''

class InjectionsHook(Hook):
    '''
    Hook carrying a keyed list of injections.

    Registration is not implemented: add_hook_patches always raises NotImplementedError.
    '''
    def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
                 hook_scope=EnumHookScope.AllConditioning):
        super().__init__(hook_type=EnumHookType.Injections, hook_scope=hook_scope)
        self.injections = injections
        self.key = key

    def clone(self):
        '''Clone base attributes; a non-empty injections list is shallow-copied, otherwise passed through as-is.'''
        duplicate: InjectionsHook = super().clone()
        duplicate.key = self.key
        if self.injections:
            duplicate.injections = self.injections.copy()
        else:
            # None or empty list: keep the very same value
            duplicate.injections = self.injections
        return duplicate

    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
        '''Not supported; always raises NotImplementedError.'''
        raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")

class HookGroup:
    '''
    Stores groups of hooks, and allows them to be queried by type.

    To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
    always use the provided functions on HookGroup.
    '''
    def __init__(self):
        # every hook in the group, in insertion order
        self.hooks: list[Hook] = []
        # the same hooks bucketed by hook_type, used by get_type()
        self._hook_dict: dict[EnumHookType, list[Hook]] = {}

    def __len__(self):
        '''Number of hooks in the group.'''
        return len(self.hooks)

    def add(self, hook: Hook):
        '''Add hook unless an equal hook (per Hook.__eq__) is already present.'''
        if hook not in self.hooks:
            self.hooks.append(hook)
            self._hook_dict.setdefault(hook.hook_type, []).append(hook)

    def remove(self, hook: Hook):
        '''Remove hook from both the list and its type bucket; no-op if not present.'''
        if hook in self.hooks:
            self.hooks.remove(hook)
            self._hook_dict[hook.hook_type].remove(hook)

    def get_type(self, hook_type: EnumHookType):
        '''Return the hooks of the given type; an empty list if there are none.'''
        return self._hook_dict.get(hook_type, [])

    def contains(self, hook: Hook):
        '''Return True if an equal hook is in the group.'''
        return hook in self.hooks

    def is_subset_of(self, other: HookGroup):
        '''Return True if every hook of this group is also in other.'''
        self_hooks = set(self.hooks)
        other_hooks = set(other.hooks)
        return self_hooks.issubset(other_hooks)

    def new_with_common_hooks(self, other: HookGroup):
        '''Return a new group holding clones of this group's hooks that other also contains.'''
        c = HookGroup()
        for hook in self.hooks:
            if other.contains(hook):
                c.add(hook.clone())
        return c

    def clone(self):
        '''Return a new group holding a clone of every hook in this group.'''
        c = HookGroup()
        for hook in self.hooks:
            c.add(hook.clone())
        return c

    def clone_and_combine(self, other: HookGroup):
        '''
        Return a clone of this group with clones of other's hooks added.

        other may be None, in which case the result is just a clone of this group.
        Hooks of other that equal a hook already present are skipped by add().
        '''
        c = self.clone()
        if other is not None:
            for hook in other.hooks:
                c.add(hook.clone())
        return c

    def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
        '''
        Assign one keyframe group to every hook in this group.

        A None hook_kf is replaced by an empty HookKeyframeGroup; otherwise a clone of
        hook_kf is used. All hooks end up sharing that single keyframe group object.
        '''
        if hook_kf is None:
            hook_kf = HookKeyframeGroup()
        else:
            hook_kf = hook_kf.clone()
        for hook in self.hooks:
            hook.hook_keyframe = hook_kf

    def get_hooks_for_clip_schedule(self):
        '''
        Split the 0.0-1.0 percent range into sub-ranges and report, per sub-range,
        which keyframe each WeightHook in this group uses.

        Returns:
            A list of (t_range, hooks_schedule) tuples, where t_range is a
            (start_percent, end_percent) pair and hooks_schedule is a list of
            (hook, keyframe) pairs. keyframe is None when the hook has no keyframes
            or when none of its ranges overlaps t_range.
        '''
        scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
        # only care about WeightHooks, for now
        for hook in self.get_type(EnumHookType.Weight):
            hook: WeightHook
            hook_schedule = []
            # if no hook keyframes, assign default value
            if len(hook.hook_keyframe.keyframes) == 0:
                hook_schedule.append(((0.0, 1.0), None))
                scheduled_hooks[hook] = hook_schedule
                continue
            # find ranges of values
            prev_keyframe = hook.hook_keyframe.keyframes[0]
            for keyframe in hook.hook_keyframe.keyframes:
                # a later keyframe with a different strength closes the range of the previous keyframe
                if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
                    hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
                    prev_keyframe = keyframe
                # keyframes sharing a start_percent: the last one encountered wins
                elif keyframe.start_percent == prev_keyframe.start_percent:
                    prev_keyframe = keyframe
            # create final range, assuming last start_percent was not 1.0
            if not math.isclose(prev_keyframe.start_percent, 1.0):
                hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
            scheduled_hooks[hook] = hook_schedule
        # hooks should now have their schedules in a list of tuples
        all_ranges: list[tuple[float, float]] = []
        for range_kfs in scheduled_hooks.values():
            for t_range, keyframe in range_kfs:
                all_ranges.append(t_range)
        # turn list of ranges into boundaries
        boundaries_set = set(itertools.chain.from_iterable(all_ranges))
        # make sure the first real range always starts at 0.0
        boundaries_set.add(0.0)
        boundaries = sorted(boundaries_set)
        real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
        # with real ranges defined, give appropriate hooks w/ keyframes for each range
        scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
        for t_range in real_ranges:
            hooks_schedule = []
            for hook, val in scheduled_hooks.items():
                keyframe = None
                # check if is a keyframe that works for the current t_range
                for stored_range, stored_kf in val:
                    # the stored range overlaps the current range - give it the assigned keyframe;
                    # the first overlapping stored range wins
                    if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
                        keyframe = stored_kf
                        break
                hooks_schedule.append((hook, keyframe))
            scheduled_keyframes.append((t_range, hooks_schedule))
        return scheduled_keyframes

    def reset(self):
        '''Reset every hook in the group.'''
        for hook in self.hooks:
            hook.reset()

    @staticmethod
    def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
        '''
        Combine several hook groups into one, ignoring None entries.

        Args:
            hooks_list: groups to combine; entries may be None.
            require_count: minimum number of non-None groups that must be present.

        Returns:
            None if there are no non-None groups, the group itself (not a clone) if
            there is exactly one, otherwise a new group combining clones of all hooks.

        Raises:
            Exception: if fewer than require_count non-None groups were provided.
        '''
        actual: list[HookGroup] = []
        for group in hooks_list:
            if group is not None:
                actual.append(group)
        if len(actual) < require_count:
            raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
        # if no hooks, then return None
        if len(actual) == 0:
            return None
        # if only 1 hook, just return itself without cloning
        elif len(actual) == 1:
            return actual[0]
        final_hook: HookGroup = None
        for hook in actual:
            # first group seeds the result; every following group is merged into a fresh clone
            if final_hook is None:
                final_hook = hook.clone()
            else:
                final_hook = final_hook.clone_and_combine(hook)
        return final_hook


class HookKeyframe:
    '''
    A single scheduling point: a strength that becomes active at start_percent.

    start_t holds the start in terms of sigma; it keeps a very large placeholder
    value until something (e.g. HookKeyframeGroup.initialize_timesteps) assigns it.
    '''
    def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
        self.strength = strength
        self.guarantee_steps = guarantee_steps
        # scheduling
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9

    def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
        '''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
        starts_before_range = self.start_t > max_sigma
        return 0 if starts_before_range else self.guarantee_steps

    def clone(self):
        '''Return a new HookKeyframe with the same strength, scheduling values and start_t.'''
        duplicate = HookKeyframe(self.strength, self.start_percent, self.guarantee_steps)
        duplicate.start_t = self.start_t
        return duplicate

class HookKeyframeGroup:
    '''
    Ordered collection of HookKeyframes plus the cursor state needed to step
    through them during sampling.

    Keyframes are kept sorted by start_percent. The "current" keyframe determines
    the strength reported by the group; prepare_current_keyframe advances it.
    '''
    def __init__(self):
        self.keyframes: list[HookKeyframe] = []
        self._current_keyframe: HookKeyframe = None
        # number of steps the current keyframe has been used for
        self._current_used_steps = 0
        self._current_index = 0
        self._current_strength = None
        # last timestep prepare_current_keyframe was run for
        self._curr_t = -1.

    # mirrors the strength attribute of the current HookKeyframe
    @property
    def strength(self):
        '''Strength of the current keyframe, or 1.0 when there is no current keyframe.'''
        if self._current_keyframe is not None:
            return self._current_keyframe.strength
        return 1.0

    def reset(self):
        '''Restore the cursor state to its initial values and point at the first keyframe again.'''
        self._current_keyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        self._current_strength = None
        # must clear the private cache read by prepare_current_keyframe; otherwise a new
        # run starting at the previously seen timestep would be skipped as "already prepared"
        self._curr_t = -1.
        self._set_first_as_current()

    def add(self, keyframe: HookKeyframe):
        '''Add keyframe, keep the list sorted by start_percent, and make the first keyframe current.'''
        # add to end of list, then sort
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
        self._set_first_as_current()

    def _set_first_as_current(self):
        '''Point the current keyframe at the first keyframe, or None when empty.'''
        if len(self.keyframes) > 0:
            self._current_keyframe = self.keyframes[0]
        else:
            self._current_keyframe = None

    def has_guarantee_steps(self):
        '''Return True if any keyframe has guarantee_steps greater than zero.'''
        return any(kf.guarantee_steps > 0 for kf in self.keyframes)

    def has_index(self, index: int):
        '''Return True if index is a valid position in the keyframe list.'''
        return 0 <= index < len(self.keyframes)

    def is_empty(self):
        '''Return True if the group holds no keyframes.'''
        return len(self.keyframes) == 0

    def clone(self):
        '''Return a new group with clones of all keyframes and a fresh cursor state.'''
        c = HookKeyframeGroup()
        for keyframe in self.keyframes:
            c.keyframes.append(keyframe.clone())
        c._set_first_as_current()
        return c

    def initialize_timesteps(self, model: BaseModel):
        '''Compute start_t of every keyframe from its start_percent via model.model_sampling.percent_to_sigma.'''
        for keyframe in self.keyframes:
            keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
        '''
        Advance the current keyframe for timestep curr_t.

        Args:
            curr_t: current timestep, in terms of sigma.
            transformer_options: must contain "sample_sigmas"; its maximum is used to
                decide whether a keyframe's guarantee_steps apply.

        Returns:
            True if both the current index and the current strength differ from their
            values before the call; False otherwise, including when the group is empty
            or curr_t was already prepared.
        '''
        if self.is_empty():
            return False
        if curr_t == self._curr_t:
            return False
        max_sigma = torch.max(transformer_options["sample_sigmas"])
        prev_index = self._current_index
        prev_strength = self._current_strength
        # if met guaranteed steps, look for next keyframe in case need to switch
        if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
            # if has next index, loop through and see if need to switch
            if self.has_index(self._current_index+1):
                for i in range(self._current_index+1, len(self.keyframes)):
                    eval_c = self.keyframes[i]
                    # check if start_t is greater or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_c.start_t >= curr_t:
                        self._current_index = i
                        self._current_strength = eval_c.strength
                        self._current_keyframe = eval_c
                        self._current_used_steps = 0
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
                            break
                    # if eval_c is outside the percent range, stop looking further
                    else:
                        break
        # update steps current context is used
        self._current_used_steps += 1
        # update current timestep this was performed on
        self._curr_t = curr_t
        # return True if keyframe changed, False if no change
        return prev_index != self._current_index and prev_strength != self._current_strength


class InterpolationMethod:
    '''
    Names of the supported interpolation curves and a helper to sample them.
    '''
    LINEAR = "linear"
    EASE_IN = "ease_in"
    EASE_OUT = "ease_out"
    EASE_IN_OUT = "ease_in_out"

    _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]

    @classmethod
    def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
        '''
        Sample `length` values going from num_from to num_to along the chosen curve.

        Args:
            num_from: value at the start of the curve.
            num_to: value at the end of the curve.
            length: number of samples to produce.
            method: one of LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT.
            reverse: when True, the resulting samples are returned in reversed order.

        Returns:
            A 1-D torch.Tensor with `length` elements.

        Raises:
            ValueError: if method is not a recognized interpolation method.
        '''
        if method == cls.LINEAR:
            weights = torch.linspace(num_from, num_to, length)
        elif method in (cls.EASE_IN, cls.EASE_OUT, cls.EASE_IN_OUT):
            # normalized position along the curve, 0 -> 1
            index = torch.linspace(0, 1, length)
            # only torch ops are used so that the result is always a torch.Tensor
            # (which the .flip call below requires), without relying on NumPy interop
            if method == cls.EASE_IN:
                eased = torch.pow(index, 2)
            elif method == cls.EASE_OUT:
                eased = 1 - torch.pow(1 - index, 2)
            else:
                eased = (1 - torch.cos(index * math.pi)) / 2
            weights = (num_to - num_from) * eased + num_from
        else:
            raise ValueError(f"Unrecognized interpolation method '{method}'.")
        if reverse:
            weights = weights.flip(dims=(0,))
        return weights

def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    '''
    Return the objects sorted by the value of their `attr` attribute.

    Objects with equal attribute values keep their relative order (the built-in
    sort is stable). The input list is not modified.

    Args:
        objects: objects to sort; a falsy value (None, empty list) is returned as-is.
        attr: name of the attribute to sort by.

    Returns:
        The same falsy `objects` that was passed in, or a new sorted list.
    '''
    if not objects:
        return objects
    # sorted() is guaranteed stable, which gives the "ties keep their order" rule
    # directly and does not require the attribute values to be hashable
    return sorted(objects, key=lambda obj: getattr(obj, attr))

def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup,  transformer_options: dict[str]=None):
    '''
    Let every TransformerOptions hook in `hooks` contribute to transformer_options.

    Args:
        model: patcher being sampled; clip patchers are ignored.
        hooks: group of hooks to apply, or None.
        transformer_options: dict to merge into (modified in place); a new dict is
            created when None.

    Returns:
        An empty dict when hooks is None or model.is_clip is truthy; otherwise the
        (possibly newly created) transformer_options dict.
    '''
    # if no hooks or is not a ModelPatcher for sampling, return empty dict
    if hooks is None or model.is_clip:
        return {}
    options = {} if transformer_options is None else transformer_options
    for t_hook in hooks.get_type(EnumHookType.TransformerOptions):
        t_hook: TransformerOptionsHook
        t_hook.on_apply_hooks(model, options)
    return options

def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
    '''
    Wrap lora weights in a WeightHook and return it inside a new HookGroup.

    The hook keeps need_weight_init at its default, so the weights are mapped onto
    the target's keys when the hook is registered.
    '''
    lora_hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
    lora_hook.weights = lora
    group = HookGroup()
    group.add(lora_hook)
    return group

def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
    '''
    Build a WeightHook whose patches replace-in whole model/clip weights.

    Each provided weight is wrapped as a ("model_as_lora", (weight,)) patch. Either
    weights argument may be None, in which case the matching patches are None.
    need_weight_init is turned off since the patches are already in final form.

    Returns:
        A new HookGroup holding the single WeightHook.
    '''
    def _wrap_as_patches(weights):
        # None stays None; otherwise every key gets a "model_as_lora" patch tuple
        if weights is None:
            return None
        return {key: ("model_as_lora", (weights[key],)) for key in weights}

    model_hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
    model_hook.weights = _wrap_as_patches(weights_model)
    model_hook.weights_clip = _wrap_as_patches(weights_clip)
    model_hook.need_weight_init = False
    group = HookGroup()
    group.add(model_hook)
    return group

def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
    '''
    Return the state dict of model.model for use as patch weights.

    Args:
        model: patcher to read from; None yields None.
        discard_model_sampling: when True, every key starting with "model_sampling"
            is removed from the returned dict.

    Returns:
        None if model is None, otherwise the dict returned by model.model.state_dict()
        (filtered in place when discard_model_sampling is True).
    '''
    if model is None:
        return None
    state: dict[str, torch.Tensor] = model.model.state_dict()
    if not discard_model_sampling:
        return state
    # do not include ANY model_sampling components of the model that should act as a patch
    sampling_keys = [name for name in list(state.keys()) if name.startswith("model_sampling")]
    for name in sampling_keys:
        state.pop(name, None)
    return state

# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
                              strength_model: float, strength_clip: float):
    '''
    Register lora weights as hook patches on clones of the given model and clip.

    Args:
        model: model patcher to clone and patch, or None.
        clip: clip to clone and patch, or None.
        lora: lora weights to load.
        strength_model: patch strength used for the model.
        strength_clip: patch strength used for the clip.

    Returns:
        (new_modelpatcher, new_clip, hook_group): the patched clones (None for inputs
        that were None) and a HookGroup holding the WeightHook the patches belong to.
        A warning is logged for every loaded key applied to neither clone.
    '''
    # collect the lora key mapping from whichever targets were provided
    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    hook_group = HookGroup()
    hook = WeightHook()
    hook_group.add(hook)
    loaded: dict[str] = comfy.lora.load_lora(lora, key_map)

    new_modelpatcher = None
    model_applied = ()
    if model is not None:
        new_modelpatcher = model.clone()
        model_applied = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)

    new_clip = None
    clip_applied = ()
    if clip is not None:
        new_clip = clip.clone()
        clip_applied = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)

    # report every loaded key that ended up on neither the model nor the clip
    applied = set(model_applied) | set(clip_applied)
    for name in loaded:
        if name not in applied:
            logging.warning(f"NOT LOADED {name}")
    return (new_modelpatcher, new_clip, hook_group)

def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
    hooks_key = 'hooks'
    # if hooks only exist in one dict, do what's needed so that it ends up in c_dict
    if hooks_key not in values:
        return
    if hooks_key not in c_dict:
        hooks_value = values.get(hooks_key, None)
        if hooks_value is not None:
            c_dict[hooks_key] = hooks_value
        return
    # otherwise, need to combine with minimum duplication via cache
    hooks_tuple = (c_dict[hooks_key], values[hooks_key])
    cached_hooks = cache.get(hooks_tuple, None)
    if cached_hooks is None:
        new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
        cache[hooks_tuple] = new_hooks
        c_dict[hooks_key] = new_hooks
    else:
        c_dict[hooks_key] = cache[hooks_tuple]

def conditioning_set_values_with_hooks(conditioning, values=None, append_hooks=True,
                                       cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
    '''
    Return a copy of conditioning with `values` applied to every entry's options dict.

    Each conditioning entry is expected to be indexable as (item, options_dict); the
    options dict is shallow-copied so the input conditioning is left untouched.

    Args:
        conditioning: iterable of conditioning entries.
        values: key/value pairs to set on each entry; None is treated as no values.
        append_hooks: when True, a 'hooks' value is combined with hooks already
            present on the entry instead of overwriting them.
        cache: optional cache of already combined hook groups, shared across calls
            to avoid creating duplicate combinations.

    Returns:
        A new list of [item, options_dict] entries.
    '''
    # None sentinels instead of mutable defaults, so no dict is shared between calls
    if values is None:
        values = {}
    if cache is None:
        cache = {}
    c = []
    for t in conditioning:
        n = [t[0], t[1].copy()]
        for k in values:
            if append_hooks and k == 'hooks':
                _combine_hooks_from_values(n[1], values, cache)
            else:
                n[1][k] = values[k]
        c.append(n)

    return c

def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
    '''
    Attach hooks to every entry of cond.

    Returns cond itself when hooks is None; otherwise a new conditioning list with
    the 'hooks' value set (or combined with existing hooks when append_hooks is True).
    '''
    if hooks is not None:
        hook_values = {'hooks': hooks}
        return conditioning_set_values_with_hooks(cond, hook_values, append_hooks=append_hooks, cache=cache)
    return cond

def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
    '''
    Set start_percent/end_percent on cond from a (start, end) pair.

    Returns cond itself when timestep_range is None.
    '''
    if timestep_range is not None:
        percents = {"start_percent": timestep_range[0],
                    "end_percent": timestep_range[1]}
        return conditioning_set_values(cond, percents)
    return cond

def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
    '''
    Attach a mask and its strength to cond.

    Args:
        cond: conditioning to update.
        mask: mask tensor, or None to leave cond untouched; a mask with fewer than
            3 dimensions gets a leading dimension added.
        set_cond_area: any value other than 'default' turns on set_area_to_bounds.
        strength: value stored as mask_strength.

    Returns:
        cond itself when mask is None, otherwise the result of conditioning_set_values.
    '''
    if mask is None:
        return cond
    if len(mask.shape) < 3:
        mask = mask.unsqueeze(0)
    use_bounds = True if set_cond_area != 'default' else False
    mask_values = {'mask': mask,
                   'set_area_to_bounds': use_bounds,
                   'mask_strength': strength}
    return conditioning_set_values(cond, mask_values)

def combine_conditioning(conds: list):
    '''Flatten a list of conditionings into one new list holding all of their entries, in order.'''
    return [entry for cond in conds for entry in cond]

def combine_with_new_conds(conds: list, new_conds: list):
    '''
    Pairwise combine each conditioning in conds with the matching one in new_conds.

    Pairing stops at the shorter of the two lists.
    '''
    return [combine_conditioning([existing, addition])
            for existing, addition in zip(conds, new_conds)]

def set_conds_props(conds: list, strength: float, set_cond_area: str,
                   mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    '''
    Apply hooks, mask and timestep range to every conditioning in conds.

    Each of hooks, mask and timesteps_range is optional; a None value skips that step.
    A single hook-combination cache is shared across all conditionings.

    Returns:
        A new list with one updated conditioning per input conditioning.
    '''
    hooks_cache = {}

    def _with_props(cond):
        # hooks first, then mask, then timesteps
        cond = set_hooks_for_conditioning(cond, hooks, append_hooks=append_hooks, cache=hooks_cache)
        cond = set_mask_for_conditioning(cond=cond, mask=mask, strength=strength, set_cond_area=set_cond_area)
        return set_timesteps_for_conditioning(cond=cond, timestep_range=timesteps_range)

    return [_with_props(c) for c in conds]

def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
                               mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    '''
    Apply hooks, mask and timestep range to each new conditioning, then append it
    to the matching existing conditioning.

    conds and new_conds are paired positionally; pairing stops at the shorter list.
    A single hook-combination cache is shared across all pairs.

    Returns:
        A new list with one combined conditioning per pair.
    '''
    hooks_cache = {}

    def _with_props(cond):
        # hooks first, then mask, then timesteps
        cond = set_hooks_for_conditioning(cond, hooks, append_hooks=append_hooks, cache=hooks_cache)
        cond = set_mask_for_conditioning(cond=cond, mask=mask, set_cond_area=set_cond_area, strength=strength)
        return set_timesteps_for_conditioning(cond=cond, timestep_range=timesteps_range)

    return [combine_conditioning([existing, _with_props(addition)])
            for existing, addition in zip(conds, new_conds)]

def set_default_conds_and_combine(conds: list, new_conds: list,
                                   hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    '''
    Mark each new conditioning as a default cond, then append it to the matching
    existing conditioning.

    For every new conditioning: hooks are applied (if given), the 'default' key is
    set to True, and the timestep range is applied (if given). conds and new_conds
    are paired positionally; pairing stops at the shorter list.

    Returns:
        A new list with one combined conditioning per pair.
    '''
    hooks_cache = {}

    def _as_default(cond):
        cond = set_hooks_for_conditioning(cond, hooks, append_hooks=append_hooks, cache=hooks_cache)
        # add default_cond key to cond so that during sampling, it can be identified
        cond = conditioning_set_values(cond, {'default': True})
        return set_timesteps_for_conditioning(cond=cond, timestep_range=timesteps_range)

    return [combine_conditioning([existing, _as_default(addition)])
            for existing, addition in zip(conds, new_conds)]