mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-24 00:17:02 +08:00
Made patch_hook_weight_to_device respect set_func and convert_func
This commit is contained in:
parent
815c6f36e1
commit
8b2c324cf6
@ -1031,7 +1031,9 @@ class ModelPatcher:
|
||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||
if key not in combined_patches:
|
||||
return
|
||||
weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
|
||||
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
weight: torch.Tensor
|
||||
if key not in self.hook_backup:
|
||||
target_device = self.offload_device
|
||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||
@ -1040,11 +1042,19 @@ class ModelPatcher:
|
||||
target_device = weight.device
|
||||
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
|
||||
# TODO: properly handle LowVramPatch, if it ends up an issue
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
|
||||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
out_weight = comfy.lora.calculate_weight(combined_patches[key],
|
||||
comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True),
|
||||
temp_weight,
|
||||
key, original_weights=original_weights)
|
||||
del original_weights[key]
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if set_func is None:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
|
||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||
# TODO: disable caching if not enough system RAM to do so
|
||||
target_device = self.offload_device
|
||||
@ -1053,9 +1063,9 @@ class ModelPatcher:
|
||||
target_device = weight.device
|
||||
self.cached_hook_patches.setdefault(hooks, {})
|
||||
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
del weight
|
||||
del temp_weight
|
||||
del out_weight
|
||||
del weight
|
||||
|
||||
def unpatch_hooks(self) -> None:
|
||||
with self.use_ejected():
|
||||
|
Loading…
x
Reference in New Issue
Block a user