Added a call to initialize_timesteps on hooks in the process_conds func, and added a call to prepare the current hook keyframes in calc_cond_batch

kosinkadink1@gmail.com 2024-09-14 16:10:42 +09:00
parent 1268d04295
commit f160d46340
3 changed files with 36 additions and 21 deletions


@@ -21,6 +21,7 @@ class HookWeight:
         return self.hook_keyframe.strength
 
     def initialize_timesteps(self, model: 'BaseModel'):
         self.reset()
+        self.hook_keyframe.initialize_timesteps(model)
 
     def reset(self):
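
For context, the added line makes the hook forward timestep initialization to its keyframe group after clearing its own transient state. A minimal sketch of that delegation, assuming illustrative HookKeyframe and HookKeyframeGroup shapes rather than the actual comfy.hooks classes (percent_to_sigma does exist on ComfyUI model sampling objects; the rest is assumption):

# Illustrative sketch only; field names beyond percent_to_sigma are assumptions.
class HookKeyframe:
    def __init__(self, strength: float, start_percent: float):
        self.strength = strength
        self.start_percent = start_percent
        self.start_t = 999999999.9  # resolved by initialize_timesteps

    def initialize_timesteps(self, model):
        # convert the schedule percent into the model's sigma/timestep space
        self.start_t = model.model_sampling.percent_to_sigma(self.start_percent)

class HookKeyframeGroup:
    def __init__(self, keyframes=None):
        self.keyframes = list(keyframes or [])

    def initialize_timesteps(self, model):
        for keyframe in self.keyframes:
            keyframe.initialize_timesteps(model)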


@@ -173,10 +173,17 @@ class ModelPatcher:
             return True
         return False
 
-    def clone_has_same_weights(self, clone):
+    def clone_has_same_weights(self, clone: 'ModelPatcher'):
         if not self.is_clone(clone):
             return False
+        if len(self.hook_patches) > 0: # TODO: check if this workaround is necessary
+            return False
+        if self.current_hooks != clone.current_hooks:
+            return False
+        if self.hook_patches.keys() != clone.hook_patches.keys():
+            return False
         if len(self.patches) == 0 and len(clone.patches) == 0:
             return True
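
The effect of the new checks, shown in isolation: clones only count as weight-identical when neither side carries hook patches and both track the same active hooks, since hook patches can leave the two models with diverging weights. A runnable toy version, with SimpleNamespace standing in for ModelPatcher:

from types import SimpleNamespace

def clones_have_same_weights(a, b) -> bool:
    if len(a.hook_patches) > 0:  # workaround: any hook patches disqualify the shortcut
        return False
    if a.current_hooks != b.current_hooks:  # different active hooks, weights may differ
        return False
    if a.hook_patches.keys() != b.hook_patches.keys():
        return False
    return len(a.patches) == 0 and len(b.patches) == 0

base = SimpleNamespace(hook_patches={}, current_hooks=None, patches={})
clone = SimpleNamespace(hook_patches={}, current_hooks=None, patches={})
assert clones_have_same_weights(base, clone)

clone.current_hooks = object()  # any hook divergence breaks the equivalence
assert not clones_have_same_weights(base, clone)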
@@ -557,9 +564,8 @@ class ModelPatcher:
     def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
         self.hook_mode = hook_mode
 
-    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_groups: List[comfy.hooks.HookWeightGroup]):
+    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookWeightGroup):
         curr_t = t[0]
-        for hook_group in hook_groups:
         for hook in hook_group.hooks:
             changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
         # if keyframe changed, remove any cached LoraHookGroups that contain hook with the same hook_ref;
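
For context on the per-hook call: prepare_current_keyframe selects the keyframe whose start boundary curr_t has reached and reports whether that selection changed, which is what drives the cache eviction mentioned in the comment. A sketch of the selection under assumed fields, treating curr_t as a sigma that decreases as sampling progresses:

# Sketch with assumed fields: keyframes sorted by ascending start_percent,
# so start_t (sigma) descends; a keyframe activates once curr_t <= start_t.
def select_current_keyframe(keyframes, current_index, curr_t):
    new_index = current_index
    for i, keyframe in enumerate(keyframes):
        if curr_t <= keyframe.start_t:
            new_index = i
    changed = new_index != current_index
    return new_index, changed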
@@ -596,7 +602,7 @@ class ModelPatcher:
             current_patches: List[Tuple] = current_hook_patches.get(key, [])
             if is_diff:
                 # take difference between desired weight and existing weight to get diff
-                # TODO: try to implement diff his via strength_path/strength_model diff
+                # TODO: try to implement diff via strength_path/strength_model diff
                 model_dtype = comfy.utils.get_attr(self.model, key).dtype
                 if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
                     diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, key).to(torch.float32)).to(model_dtype)
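
The float8 branch above exists because PyTorch does not implement subtraction for float8 tensors directly; the operands have to be upcast, diffed in float32, and cast back. A self-contained check of that behavior, assuming a PyTorch build with float8 dtypes:

import torch

a = torch.randn(4).to(torch.float8_e4m3fn)
b = torch.randn(4).to(torch.float8_e4m3fn)

# a - b would raise, since sub is not implemented for float8 dtypes;
# upcast, take the difference, then cast back
diff = (a.to(torch.float32) - b.to(torch.float32)).to(torch.float8_e4m3fn)
print(diff.dtype)  # torch.float8_e4m3fn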


@@ -157,9 +157,11 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
             p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
             if p is None:
                 continue
-            hook: comfy.hooks.HookWeightGroup = x.get('hooks', None)
-            hooked_to_run.setdefault(hook, list())
-            hooked_to_run[hook] += [(p, i)]
+            hooks: comfy.hooks.HookWeightGroup = x.get('hooks', None)
+            if hooks is not None:
+                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, hooks)
+            hooked_to_run.setdefault(hooks, list())
+            hooked_to_run[hooks] += [(p, i)]
 
     # run every hooked_to_run separately
     for hooks, to_run in hooked_to_run.items():
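
The dict this hunk feeds groups conds by their hook group (None for unhooked conds) so each distinct hook set runs as its own batch; the new prepare_hook_patches_current_keyframe call ensures each group's keyframes are resolved for the current timestep before that happens. A toy version of the grouping pattern with hypothetical cond dicts:

# Toy version of the hooked_to_run grouping; cond contents are hypothetical.
conds = [
    {'name': 'pos_a', 'hooks': 'group_a'},
    {'name': 'neg'},                        # no hooks: grouped under None
    {'name': 'pos_b', 'hooks': 'group_a'},
]

hooked_to_run = {}
for i, x in enumerate(conds):
    hooks = x.get('hooks', None)
    hooked_to_run.setdefault(hooks, list())
    hooked_to_run[hooks] += [(x, i)]

# each hook group now runs as its own batch
for hooks, to_run in hooked_to_run.items():
    print(hooks, [c['name'] for c, _ in to_run])
# group_a ['pos_a', 'pos_b']
# None ['neg']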
@@ -658,6 +660,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
             if k != kk:
                 create_cond_with_same_area_if_none(conds[kk], c)
 
+    for k in conds:
+        for c in conds[k]:
+            if 'hooks' in c:
+                for hook in c['hooks'].hooks:
+                    hook.initialize_timesteps(model)
+
     for k in conds:
         pre_run_control(model, conds[k])
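
As a usage note, the new loop runs once before pre_run_control, so every hook attached to any cond resolves its keyframe boundaries against the model before the first sampling step. A minimal mock of the traversal, with stand-in classes for the hook group and model:

# Mock objects; only the traversal shape mirrors the diff above.
class MockHook:
    def __init__(self, name):
        self.name = name
    def initialize_timesteps(self, model):
        print(f"initialized {self.name}")

class MockHookGroup:
    def __init__(self, hooks):
        self.hooks = hooks

model = object()
conds = {
    'positive': [{'hooks': MockHookGroup([MockHook('h1'), MockHook('h2')])}],
    'negative': [{}],  # conds without hooks are skipped
}

for k in conds:
    for c in conds[k]:
        if 'hooks' in c:
            for hook in c['hooks'].hooks:
                hook.initialize_timesteps(model)
# prints: initialized h1, initialized h2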