Mirror of https://github.com/comfyanonymous/ComfyUI.git
Added call to initialize_timesteps on hooks in the process_conds func, and added a call to prepare the current keyframe on hooks in calc_cond_batch
commit f160d46340
parent 1268d04295
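In short: process_conds now calls initialize_timesteps on every hook once per sampling run, and calc_cond_batch calls prepare_hook_patches_current_keyframe on every step, so a hook's strength can follow a keyframe schedule as denoising progresses. A minimal sketch of that scheduling idea, using simplified stand-ins rather than the real comfy.hooks classes (percent_to_t is a hypothetical helper; the initalize_timesteps spelling mirrors the identifier in the diff):

    # Simplified stand-ins for the comfy.hooks classes touched by this diff (sketch only).
    class HookKeyframe:
        def __init__(self, strength, start_percent):
            self.strength = strength
            self.start_percent = start_percent  # 0.0..1.0 of the sampling run
            self.start_t = None                 # resolved against the model later

    class HookKeyframeGroup:
        def __init__(self, keyframes):
            self.keyframes = sorted(keyframes, key=lambda kf: kf.start_percent)
            self._curr_idx = 0

        def initalize_timesteps(self, model):
            # once per run: convert percents into the model's timestep values
            # (percent_to_t is a hypothetical helper standing in for the real lookup)
            for kf in self.keyframes:
                kf.start_t = model.percent_to_t(kf.start_percent)

        def prepare_current_keyframe(self, curr_t):
            # once per step: advance to the keyframe active at curr_t; report change
            prev = self._curr_idx
            while (self._curr_idx + 1 < len(self.keyframes)
                   and curr_t <= self.keyframes[self._curr_idx + 1].start_t):
                self._curr_idx += 1
            return self._curr_idx != prev

        @property
        def strength(self):
            return self.keyframes[self._curr_idx].strength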
@@ -21,6 +21,7 @@ class HookWeight:
         return self.hook_keyframe.strength
 
     def initialize_timesteps(self, model: 'BaseModel'):
+        self.reset()
         self.hook_keyframe.initalize_timesteps(model)
 
     def reset(self):
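Given the `-21,6 +21,7` header, exactly one line here is new; the `self.reset()` call is the natural fit, clearing per-run keyframe state before timesteps are resolved so a reused hook does not start mid-schedule. A stand-in HookWeight to match, building on the sketch classes above (illustrative, not the real class):

    class HookWeight:
        def __init__(self, hook_keyframe):
            self.hook_keyframe = hook_keyframe  # a HookKeyframeGroup (sketch above)

        @property
        def strength(self):
            return self.hook_keyframe.strength

        def initialize_timesteps(self, model):
            self.reset()  # clear per-run state before resolving timesteps
            self.hook_keyframe.initalize_timesteps(model)

        def reset(self):
            self.hook_keyframe._curr_idx = 0  # rewind keyframe tracking (sketch)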
@@ -173,10 +173,17 @@ class ModelPatcher:
             return True
         return False
 
-    def clone_has_same_weights(self, clone):
+    def clone_has_same_weights(self, clone: 'ModelPatcher'):
         if not self.is_clone(clone):
             return False
 
+        if len(self.hook_patches) > 0: # TODO: check if this workaround is necessary
+            return False
+        if self.current_hooks != clone.current_hooks:
+            return False
+        if self.hook_patches.keys() != clone.hook_patches.keys():
+            return False
+
         if len(self.patches) == 0 and len(clone.patches) == 0:
             return True
 
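The three added checks fold hook state into weight equality: patchers with pending hook patches, different currently applied hooks, or different hooked keys are no longer treated as interchangeable. The `.keys()` comparison works because Python dict views compare as sets:

    # dict key views compare like sets, which the new
    # self.hook_patches.keys() != clone.hook_patches.keys() check relies on:
    a = {"diffusion_model.w1": [1], "diffusion_model.w2": [2]}
    b = {"diffusion_model.w2": [9], "diffusion_model.w1": [8]}
    print(a.keys() == b.keys())  # True: same hooked keys; order and values ignored
    c = {"diffusion_model.w1": [1]}
    print(a.keys() == c.keys())  # False: different set of hooked keys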
@@ -557,23 +564,22 @@ class ModelPatcher:
     def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
         self.hook_mode = hook_mode
 
-    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_groups: List[comfy.hooks.HookWeightGroup]):
+    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookWeightGroup):
         curr_t = t[0]
-        for hook_group in hook_groups:
-            for hook in hook_group.hooks:
-                changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
-                # if keyframe changed, remove any cached LoraHookGroups that contain hook with the same hook_ref;
-                # this will cause the weights to be recalculated when sampling
-                if changed:
-                    # reset current_lora_hooks if contains lora hook that changed
-                    if self.current_hooks is not None:
-                        for current_hook in self.current_hooks.hooks:
-                            if current_hook == hook:
-                                self.current_hooks = None
-                                break
-                    for cached_group in list(self.cached_hook_patches.keys()):
-                        if cached_group.contains(hook):
-                            self.cached_hook_patches.pop(cached_group)
+        for hook in hook_group.hooks:
+            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
+            # if keyframe changed, remove any cached LoraHookGroups that contain hook with the same hook_ref;
+            # this will cause the weights to be recalculated when sampling
+            if changed:
+                # reset current_lora_hooks if contains lora hook that changed
+                if self.current_hooks is not None:
+                    for current_hook in self.current_hooks.hooks:
+                        if current_hook == hook:
+                            self.current_hooks = None
+                            break
+                for cached_group in list(self.cached_hook_patches.keys()):
+                    if cached_group.contains(hook):
+                        self.cached_hook_patches.pop(cached_group)
 
     def add_hook_patches(self, hook: comfy.hooks.HookWeight, patches, strength_patch=1.0, strength_model=1.0, is_diff=False):
         # NOTE: this mirrors behavior of add_patches func
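The signature change pushes iteration over multiple groups out to the caller (calc_cond_batch below passes one group per cond). The body itself is a cache-invalidation pattern: when any hook's keyframe flips, every cached weight set containing that hook is stale and gets dropped. The same idiom in miniature:

    # Invalidation idiom used above: drop every cached entry whose key contains
    # the changed item; list() lets us pop while iterating.
    cache = {
        frozenset({"hookA"}): "patched-weights-1",
        frozenset({"hookA", "hookB"}): "patched-weights-2",
        frozenset({"hookB"}): "patched-weights-3",
    }
    changed = "hookA"
    for cached_group in list(cache.keys()):
        if changed in cached_group:
            cache.pop(cached_group)
    print(cache)  # only {frozenset({'hookB'}): 'patched-weights-3'} survives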
@@ -596,7 +602,7 @@ class ModelPatcher:
             current_patches: List[Tuple] = current_hook_patches.get(key, [])
             if is_diff:
                 # take difference between desired weight and existing weight to get diff
-                # TODO: try to implement diff his via strength_path/strength_model diff
+                # TODO: try to implement diff via strength_path/strength_model diff
                 model_dtype = comfy.utils.get_attr(self.model, key).dtype
                 if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
                     diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, key).to(torch.float32)).to(model_dtype)
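The float8 branch exists because PyTorch (as of recent releases) does not implement arithmetic directly on float8 tensors, so the diff must round-trip through float32:

    import torch

    # fp8 tensors reject direct subtraction, so compute in float32 and cast back:
    desired = torch.randn(4).to(torch.float8_e4m3fn)
    existing = torch.randn(4).to(torch.float8_e4m3fn)
    diff = (desired.to(torch.float32) - existing.to(torch.float32)).to(torch.float8_e4m3fn)
    print(diff.dtype)  # torch.float8_e4m3fn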
@@ -157,9 +157,11 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
             p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
             if p is None:
                 continue
-            hook: comfy.hooks.HookWeightGroup = x.get('hooks', None)
-            hooked_to_run.setdefault(hook, list())
-            hooked_to_run[hook] += [(p, i)]
+            hooks: comfy.hooks.HookWeightGroup = x.get('hooks', None)
+            if hooks is not None:
+                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, hooks)
+            hooked_to_run.setdefault(hooks, list())
+            hooked_to_run[hooks] += [(p, i)]
 
     # run every hooked_to_run separately
     for hooks, to_run in hooked_to_run.items():
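calc_cond_batch buckets conds by their hook group so each distinct set of hooks runs as its own batch, with None collecting the unhooked conds; the new prepare_hook_patches_current_keyframe call happens at collection time, before any batch runs for this timestep. The bucketing idiom on its own:

    # Grouping idiom from the hunk: bucket items by a hashable (or None) key.
    hooked_to_run = {}
    conds = [{"hooks": "groupA"}, {}, {"hooks": "groupA"}]
    for i, x in enumerate(conds):
        hooks = x.get('hooks', None)
        hooked_to_run.setdefault(hooks, list())
        hooked_to_run[hooks] += [("cond-payload", i)]
    print(hooked_to_run)
    # {'groupA': [('cond-payload', 0), ('cond-payload', 2)], None: [('cond-payload', 1)]}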
@@ -658,6 +660,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
                 if k != kk:
                     create_cond_with_same_area_if_none(conds[kk], c)
 
+    for k in conds:
+        for c in conds[k]:
+            if 'hooks' in c:
+                for hook in c['hooks'].hooks:
+                    hook.initialize_timesteps(model)
+
     for k in conds:
         pre_run_control(model, conds[k])
 
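Together, the two samplers-side hunks give hooks a two-phase lifecycle: initialize once per run here in process_conds, then advance once per step in calc_cond_batch. A runnable toy of the structure the new loop walks (stand-in classes, not the real API):

    # conds maps cond type -> list of cond dicts, some carrying a 'hooks' group.
    class ToyHook:
        def initialize_timesteps(self, model):
            print(f"initialized against {model}")

    class ToyHookGroup:
        def __init__(self, *hooks):
            self.hooks = list(hooks)

    model = "toy-model"
    conds = {
        "positive": [{"hooks": ToyHookGroup(ToyHook(), ToyHook())}, {}],
        "negative": [{"hooks": ToyHookGroup(ToyHook())}],
    }
    for k in conds:
        for c in conds[k]:
            if 'hooks' in c:
                for hook in c['hooks'].hooks:
                    hook.initialize_timesteps(model)
    # prints one line per hook: each initialized exactly once per sampling run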