Fixed existing weight hook_patches (pre-registered) not working properly for CLIP

This commit is contained in:
Jedrzej Kosinski 2024-11-12 08:12:12 -06:00
parent 9330745f27
commit 1766d903ad
3 changed files with 39 additions and 36 deletions

View File

@ -215,6 +215,7 @@ class ModelPatcher:
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP
self.is_clip = False
# TODO: hook_mode should be entirely removed; behavior should be determined by remaining VRAM/memory
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
@ -291,6 +292,7 @@ class ModelPatcher:
n.hook_backup = self.hook_backup
n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
@ -613,7 +615,7 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
self.apply_hooks(self.forced_hooks)
self.apply_hooks(self.forced_hooks, force_apply=True)
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
with self.use_ejected():
@ -733,6 +735,7 @@ class ModelPatcher:
self.patch_model(load_weights=False)
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
self.apply_hooks(self.forced_hooks, force_apply=True)
return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
@ -991,9 +994,9 @@ class ModelPatcher:
combined_patches[key] = current_patches
return combined_patches
def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None):
def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks:
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {}
self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
@ -1003,31 +1006,32 @@ class ModelPatcher:
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
self.unpatch_hooks()
model_sd = self.model_state_dict()
memory_counter = None
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: minimum_counter should have a minimum that conforms to loaded model requirements
memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
minimum=comfy.model_management.minimum_inference_memory())
# if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
for key in cached_weights:
if key not in model_sd:
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
else:
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None
if len(relevant_patches) > 0:
original_weights = self.get_key_patches()
for key in relevant_patches:
if key not in model_sd:
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
if hooks is not None:
model_sd = self.model_state_dict()
memory_counter = None
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: minimum_counter should have a minimum that conforms to loaded model requirements
memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
minimum=comfy.model_management.minimum_inference_memory())
# if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
for key in cached_weights:
if key not in model_sd:
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
else:
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None
if len(relevant_patches) > 0:
original_weights = self.get_key_patches()
for key in relevant_patches:
if key not in model_sd:
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):

View File

@ -96,6 +96,7 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None

View File

@ -76,10 +76,9 @@ class PairConditioningSetPropertiesAndCombine:
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
positive_NEW, negative_NEW = comfy.hooks.set_mask_conds(conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
final_positive, final_negative = comfy.hooks.combine_with_new_conds(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW])
final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
return (final_positive, final_negative)
class ConditioningSetProperties:
@ -138,10 +137,9 @@ class ConditioningSetPropertiesAndCombine:
def set_properties(self, cond, cond_NEW,
strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
(cond_NEW,) = comfy.hooks.set_mask_conds(conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
(final_cond,) = comfy.hooks.combine_with_new_conds(conds=[cond], new_conds=[cond_NEW])
(final_cond,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond], new_conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
return (final_cond,)
class PairConditioningCombine: