Repository: https://github.com/comfyanonymous/ComfyUI (mirror)

Fixed existing weight hook_patches (pre-registered) not working properly for CLIP

The fix threads a new is_clip flag and a force_apply argument through ModelPatcher so that forced hooks (pre-registered weight hook_patches) are reapplied whenever model weights are (re)loaded, and consolidates the conditioning nodes' mask-then-combine calls into a single comfy.hooks.set_mask_and_combine_conds call.

commit 1766d903ad
parent 9330745f27
@@ -215,6 +215,7 @@ class ModelPatcher:
         self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
         self.current_hooks: Optional[comfy.hooks.HookGroup] = None
         self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP
+        self.is_clip = False
         # TODO: hook_mode should be entirely removed; behavior should be determined by remaining VRAM/memory
         self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed

@@ -291,6 +292,7 @@ class ModelPatcher:
         n.hook_backup = self.hook_backup
         n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
         n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
+        n.is_clip = self.is_clip
         n.hook_mode = self.hook_mode

         for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
@@ -613,7 +615,7 @@ class ModelPatcher:
         for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
             callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)

-        self.apply_hooks(self.forced_hooks)
+        self.apply_hooks(self.forced_hooks, force_apply=True)

     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         with self.use_ejected():
@@ -733,6 +735,7 @@ class ModelPatcher:
         self.patch_model(load_weights=False)
         full_load = False
         if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
+            self.apply_hooks(self.forced_hooks, force_apply=True)
             return 0
         if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
             full_load = True
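Note on the two call sites above: both run right after model weights are (re)loaded, so the freshly loaded tensors no longer carry any hook patches even when the requested hook group still equals current_hooks. Passing force_apply=True makes apply_hooks repatch anyway. A minimal sketch of the failure mode this fixes, using hypothetical stand-in names rather than the real ModelPatcher internals:

```python
class ToyPatcher:
    """Illustrative only; mirrors the guard in apply_hooks(), not the real class."""
    def __init__(self):
        self.is_clip = True
        self.current_hooks = None
        self.weight = "base"

    def load(self, force_apply: bool):
        self.weight = "base"  # reloading wipes any previously patched weights
        self.apply_hooks(("clip_hooks",), force_apply=force_apply)

    def apply_hooks(self, hooks, force_apply=False):
        if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
            return {}  # old behavior: skips because hooks "match", weight stays stale
        self.weight = "patched"  # stand-in for patch_hooks()
        self.current_hooks = hooks
        return {}

p = ToyPatcher()
p.apply_hooks(("clip_hooks",))   # pre-registered hooks applied once
p.load(force_apply=False)        # old call site: guard skips, weights stay unpatched
assert p.weight == "base"
p.load(force_apply=True)         # new call site: repatched after the reload
assert p.weight == "patched"
```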
@@ -991,9 +994,9 @@ class ModelPatcher:
             combined_patches[key] = current_patches
         return combined_patches

-    def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None):
+    def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None, force_apply=False):
         # TODO: return transformer_options dict with any additions from hooks
-        if self.current_hooks == hooks:
+        if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
             return {}
         self.patch_hooks(hooks=hooks)
         for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
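A subtlety in the new guard: force_apply is still a no-op for a non-CLIP patcher when no hooks are requested, so plain models keep the cheap early return. A compact truth table, restating the condition via a hypothetical should_skip helper and assuming current_hooks == hooks already holds:

```python
def should_skip(hooks, is_clip: bool, force_apply: bool) -> bool:
    # Restates the guard in apply_hooks(), given current_hooks == hooks:
    return not force_apply or (not is_clip and hooks is None)

assert should_skip(None, is_clip=False, force_apply=True)        # plain model, no hooks: skip
assert not should_skip(None, is_clip=True, force_apply=True)     # CLIP: forced repatch proceeds
assert not should_skip(("g",), is_clip=False, force_apply=True)  # real hooks forced: repatch
assert should_skip(("g",), is_clip=False, force_apply=False)     # default behavior unchanged
```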
@@ -1003,31 +1006,32 @@ class ModelPatcher:
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
             self.unpatch_hooks()
-            model_sd = self.model_state_dict()
-            memory_counter = None
-            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
-                # TODO: minimum_counter should have a minimum that conforms to loaded model requirements
-                memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
-                                               minimum=comfy.model_management.minimum_inference_memory())
-            # if have cached weights for hooks, use it
-            cached_weights = self.cached_hook_patches.get(hooks, None)
-            if cached_weights is not None:
-                for key in cached_weights:
-                    if key not in model_sd:
-                        print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
-                        continue
-                    self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
-            else:
-                relevant_patches = self.get_combined_hook_patches(hooks=hooks)
-                original_weights = None
-                if len(relevant_patches) > 0:
-                    original_weights = self.get_key_patches()
-                for key in relevant_patches:
-                    if key not in model_sd:
-                        print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
-                        continue
-                    self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
-                                                     memory_counter=memory_counter)
+            if hooks is not None:
+                model_sd = self.model_state_dict()
+                memory_counter = None
+                if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+                    # TODO: minimum_counter should have a minimum that conforms to loaded model requirements
+                    memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
+                                                   minimum=comfy.model_management.minimum_inference_memory())
+                # if have cached weights for hooks, use it
+                cached_weights = self.cached_hook_patches.get(hooks, None)
+                if cached_weights is not None:
+                    for key in cached_weights:
+                        if key not in model_sd:
+                            print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
+                            continue
+                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                else:
+                    relevant_patches = self.get_combined_hook_patches(hooks=hooks)
+                    original_weights = None
+                    if len(relevant_patches) > 0:
+                        original_weights = self.get_key_patches()
+                    for key in relevant_patches:
+                        if key not in model_sd:
+                            print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
+                            continue
+                        self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
+                                                         memory_counter=memory_counter)
             self.current_hooks = hooks

     def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
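The rest of this hunk wraps the patching body in `if hooks is not None:`, so clearing hooks only unpatches and records the new state instead of building a MemoryCounter and scanning the state dict for an empty group. A toy version of that control flow (names are illustrative, not the real API):

```python
class ToyPatcher:
    def __init__(self):
        self.current_hooks = None
        self.patched_keys = []

    def unpatch_hooks(self):
        self.patched_keys.clear()

    def patch_hooks(self, hooks):
        self.unpatch_hooks()
        if hooks is not None:
            # Only now do we pay for state-dict lookups and weight patching.
            self.patched_keys = [f"weight:{h}" for h in hooks]
        self.current_hooks = hooks

p = ToyPatcher()
p.patch_hooks(["lora_a"])
assert p.patched_keys == ["weight:lora_a"]
p.patch_hooks(None)  # clear: unpatch only, but still record the empty state
assert p.patched_keys == [] and p.current_hooks is None
```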
@@ -96,6 +96,7 @@ class CLIP:
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
+        self.patcher.is_clip = True
         if params['device'] == load_device:
             model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
@@ -76,10 +76,9 @@ class PairConditioningSetPropertiesAndCombine:
     def set_properties(self, positive, negative, positive_NEW, negative_NEW,
                        strength: float, set_cond_area: str,
                        opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
-        positive_NEW, negative_NEW = comfy.hooks.set_mask_conds(conds=[positive_NEW, negative_NEW],
-                                                                strength=strength, set_cond_area=set_cond_area,
-                                                                opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
-        final_positive, final_negative = comfy.hooks.combine_with_new_conds(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW])
+        final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
+                                                                                strength=strength, set_cond_area=set_cond_area,
+                                                                                opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
         return (final_positive, final_negative)

 class ConditioningSetProperties:
@@ -138,10 +137,9 @@ class ConditioningSetPropertiesAndCombine:
     def set_properties(self, cond, cond_NEW,
                        strength: float, set_cond_area: str,
                        opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
-        (cond_NEW,) = comfy.hooks.set_mask_conds(conds=[cond_NEW],
-                                                 strength=strength, set_cond_area=set_cond_area,
-                                                 opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
-        (final_cond,) = comfy.hooks.combine_with_new_conds(conds=[cond], new_conds=[cond_NEW])
+        (final_cond,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond], new_conds=[cond_NEW],
+                                                               strength=strength, set_cond_area=set_cond_area,
+                                                               opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
         return (final_cond,)

 class PairConditioningCombine:
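The node-side change collapses the old two-step flow, set_mask_conds on the new conds followed by combine_with_new_conds, into a single set_mask_and_combine_conds call that receives the old conds, the new conds, and all masking parameters together. A generic sketch of the shape of that refactor (toy functions, not the comfy.hooks implementations):

```python
def set_mask(new_conds, strength):
    # step 1 of the old flow: annotate only the NEW conds
    return [[(c, {"strength": strength}) for c in group] for group in new_conds]

def combine(conds, new_conds):
    # step 2 of the old flow: append the new conds onto the existing ones
    return [old + new for old, new in zip(conds, new_conds)]

def set_mask_and_combine(conds, new_conds, strength):
    # the fused helper: one call, same result, masking applied before combining
    return combine(conds, set_mask(new_conds, strength))

old = [[("pos", {})]]
new = [["pos_NEW"]]
assert set_mask_and_combine(old, new, 0.5) == combine(old, set_mask(new, 0.5))
```

Handing both cond lists to one helper lets it sequence masking and combination internally, instead of relying on each node to repeat the two calls in the right order.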