mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-29 10:57:02 +08:00
Fixed existing weight hook_patches (pre-registered) not working properly for CLIP
This commit is contained in:
parent
9330745f27
commit
1766d903ad
@ -215,6 +215,7 @@ class ModelPatcher:
|
|||||||
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
||||||
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
||||||
self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP
|
self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP
|
||||||
|
self.is_clip = False
|
||||||
# TODO: hook_mode should be entirely removed; behavior should be determined by remaining VRAM/memory
|
# TODO: hook_mode should be entirely removed; behavior should be determined by remaining VRAM/memory
|
||||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
@ -291,6 +292,7 @@ class ModelPatcher:
|
|||||||
n.hook_backup = self.hook_backup
|
n.hook_backup = self.hook_backup
|
||||||
n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
|
n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
|
||||||
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
|
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
|
||||||
|
n.is_clip = self.is_clip
|
||||||
n.hook_mode = self.hook_mode
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
@ -613,7 +615,7 @@ class ModelPatcher:
|
|||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||||
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||||
|
|
||||||
self.apply_hooks(self.forced_hooks)
|
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||||
|
|
||||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
@ -733,6 +735,7 @@ class ModelPatcher:
|
|||||||
self.patch_model(load_weights=False)
|
self.patch_model(load_weights=False)
|
||||||
full_load = False
|
full_load = False
|
||||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||||
|
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||||
return 0
|
return 0
|
||||||
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
full_load = True
|
full_load = True
|
||||||
@ -991,9 +994,9 @@ class ModelPatcher:
|
|||||||
combined_patches[key] = current_patches
|
combined_patches[key] = current_patches
|
||||||
return combined_patches
|
return combined_patches
|
||||||
|
|
||||||
def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None):
|
def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None, force_apply=False):
|
||||||
# TODO: return transformer_options dict with any additions from hooks
|
# TODO: return transformer_options dict with any additions from hooks
|
||||||
if self.current_hooks == hooks:
|
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
||||||
return {}
|
return {}
|
||||||
self.patch_hooks(hooks=hooks)
|
self.patch_hooks(hooks=hooks)
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
||||||
@ -1003,6 +1006,7 @@ class ModelPatcher:
|
|||||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
|
if hooks is not None:
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
memory_counter = None
|
memory_counter = None
|
||||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||||
|
@ -96,6 +96,7 @@ class CLIP:
|
|||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
|
self.patcher.is_clip = True
|
||||||
if params['device'] == load_device:
|
if params['device'] == load_device:
|
||||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
|
@ -76,10 +76,9 @@ class PairConditioningSetPropertiesAndCombine:
|
|||||||
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
|
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
|
||||||
strength: float, set_cond_area: str,
|
strength: float, set_cond_area: str,
|
||||||
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
|
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
|
||||||
positive_NEW, negative_NEW = comfy.hooks.set_mask_conds(conds=[positive_NEW, negative_NEW],
|
final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
|
||||||
strength=strength, set_cond_area=set_cond_area,
|
strength=strength, set_cond_area=set_cond_area,
|
||||||
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
|
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
|
||||||
final_positive, final_negative = comfy.hooks.combine_with_new_conds(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW])
|
|
||||||
return (final_positive, final_negative)
|
return (final_positive, final_negative)
|
||||||
|
|
||||||
class ConditioningSetProperties:
|
class ConditioningSetProperties:
|
||||||
@ -138,10 +137,9 @@ class ConditioningSetPropertiesAndCombine:
|
|||||||
def set_properties(self, cond, cond_NEW,
|
def set_properties(self, cond, cond_NEW,
|
||||||
strength: float, set_cond_area: str,
|
strength: float, set_cond_area: str,
|
||||||
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
|
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
|
||||||
(cond_NEW,) = comfy.hooks.set_mask_conds(conds=[cond_NEW],
|
(final_cond,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond], new_conds=[cond_NEW],
|
||||||
strength=strength, set_cond_area=set_cond_area,
|
strength=strength, set_cond_area=set_cond_area,
|
||||||
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
|
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
|
||||||
(final_cond,) = comfy.hooks.combine_with_new_conds(conds=[cond], new_conds=[cond_NEW])
|
|
||||||
return (final_cond,)
|
return (final_cond,)
|
||||||
|
|
||||||
class PairConditioningCombine:
|
class PairConditioningCombine:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user