diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 416197586..f859a50d4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -287,13 +287,13 @@ class ControlLora(ControlNet): for k in sd: weight = sd[k] try: - comfy.utils.set_attr(self.control_model, k, weight) + comfy.utils.set_attr_param(self.control_model, k, weight) except: pass for k in self.control_weights: if k not in {"lora_controlnet"}: - comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f29781f31..604e34779 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -176,10 +176,9 @@ class ModelPatcher: def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: - old = comfy.utils.get_attr(self.model, k) + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) if k not in self.object_patches_backup: self.object_patches_backup[k] = old - comfy.utils.set_attr(self.model, k, self.object_patches[k]) if patch_weights: model_sd = self.model_state_dict() @@ -203,7 +202,7 @@ class ModelPatcher: if inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: - comfy.utils.set_attr(self.model, key, out_weight) + comfy.utils.set_attr_param(self.model, key, out_weight) del temp_weight if device_to is not None: @@ -342,7 +341,7 @@ class ModelPatcher: comfy.utils.copy_to_param(self.model, k, self.backup[k]) else: for k in keys: - comfy.utils.set_attr(self.model, k, self.backup[k]) + comfy.utils.set_attr_param(self.model, k, self.backup[k]) self.backup = {} diff --git a/comfy/utils.py b/comfy/utils.py index 41f730c8e..5deb14cd2 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -294,8 +294,11 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) - del prev + setattr(obj, attrs[-1], value) + return prev + +def set_attr_param(obj, attr, value): + return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it