Made patch_hook_weight_to_device respect set_func and convert_func

This commit is contained in:
Jedrzej Kosinski 2024-11-24 13:48:30 -06:00
parent 815c6f36e1
commit 8b2c324cf6

View File

@ -1031,7 +1031,9 @@ class ModelPatcher:
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
weight, set_func, convert_func = get_key_weight(self.model, key)
weight: torch.Tensor
if key not in self.hook_backup:
target_device = self.offload_device
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
@ -1040,11 +1042,19 @@ class ModelPatcher:
target_device = weight.device
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
# TODO: properly handle LowVramPatch, if it ends up an issue
temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(combined_patches[key],
comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True),
temp_weight,
key, original_weights=original_weights)
del original_weights[key]
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: disable caching if not enough system RAM to do so
target_device = self.offload_device
@ -1053,9 +1063,9 @@ class ModelPatcher:
target_device = weight.device
self.cached_hook_patches.setdefault(hooks, {})
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
comfy.utils.copy_to_param(self.model, key, out_weight)
del weight
del temp_weight
del out_weight
del weight
def unpatch_hooks(self) -> None:
with self.use_ejected():