utils.set_attr can now be used to set any attribute.

The old set_attr has been renamed to set_attr_param.
Author: comfyanonymous
Date:   2024-03-02 17:27:23 -05:00
Commit: 1abf8374ec (parent dce3555339)

3 changed files with 10 additions and 8 deletions
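In short: set_attr now binds the value as-is and returns whatever it displaced, while set_attr_param keeps the old behavior of wrapping the value in a torch.nn.Parameter. A minimal standalone sketch of the two helpers as they read after this commit (the toy model and attribute names are illustrative only):

import torch
import torch.nn as nn

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], value)
    return prev

def set_attr_param(obj, attr, value):
    return set_attr(obj, attr, nn.Parameter(value, requires_grad=False))

model = nn.Sequential(nn.Linear(4, 4))

# a registered parameter has to stay a Parameter, so wrap it:
old_weight = set_attr_param(model, "0.weight", torch.zeros(4, 4))

# any other attribute (here a whole submodule) can now be set directly:
old_layer = set_attr(model, "0", nn.Identity())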

comfy/controlnet.py

@@ -287,13 +287,13 @@ class ControlLora(ControlNet):
         for k in sd:
             weight = sd[k]
             try:
-                comfy.utils.set_attr(self.control_model, k, weight)
+                comfy.utils.set_attr_param(self.control_model, k, weight)
             except:
                 pass
 
         for k in self.control_weights:
             if k not in {"lora_controlnet"}:
-                comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
+                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
 
     def copy(self):
         c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
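The ControlLora call sites switch to set_attr_param because the keys here are model weights: torch.nn.Module refuses to bind a plain tensor over a registered Parameter, so the value has to be wrapped first. A quick illustration of that constraint (plain PyTorch, illustrative shapes):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

# binding a plain tensor over a registered Parameter raises TypeError
try:
    layer.weight = torch.zeros(4, 4)
except TypeError as err:
    print(err)  # cannot assign 'torch.FloatTensor' as parameter 'weight' ...

# wrapping the tensor (what set_attr_param does) is accepted
layer.weight = nn.Parameter(torch.zeros(4, 4), requires_grad=False)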

comfy/model_patcher.py

@@ -176,10 +176,9 @@ class ModelPatcher:
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
-            old = comfy.utils.get_attr(self.model, k)
+            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
-            comfy.utils.set_attr(self.model, k, self.object_patches[k])
 
         if patch_weights:
             model_sd = self.model_state_dict()
@@ -203,7 +202,7 @@ class ModelPatcher:
                 if inplace_update:
                     comfy.utils.copy_to_param(self.model, key, out_weight)
                 else:
-                    comfy.utils.set_attr(self.model, key, out_weight)
+                    comfy.utils.set_attr_param(self.model, key, out_weight)
                 del temp_weight
 
         if device_to is not None:
@@ -342,7 +341,7 @@ class ModelPatcher:
                     comfy.utils.copy_to_param(self.model, k, self.backup[k])
             else:
                 for k in keys:
-                    comfy.utils.set_attr(self.model, k, self.backup[k])
+                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
 
         self.backup = {}
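The first patch_model hunk is the payoff of the new return value: a single set_attr call installs the object patch and hands back the displaced object for the backup dict, replacing the earlier get_attr + set_attr pair. The same pattern in isolation (toy model and patch dict; assumes ComfyUI is importable):

import torch.nn as nn
import comfy.utils

model = nn.Sequential(nn.Linear(4, 4))
object_patches = {"0": nn.Identity()}  # hypothetical patch set
object_patches_backup = {}

for k in object_patches:
    # set_attr installs the patch and returns what was there before,
    # which is kept so unpatch_model can restore it later
    old = comfy.utils.set_attr(model, k, object_patches[k])
    if k not in object_patches_backup:
        object_patches_backup[k] = old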

comfy/utils.py

@@ -294,8 +294,11 @@ def set_attr(obj, attr, value):
     for name in attrs[:-1]:
         obj = getattr(obj, name)
     prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
-    del prev
+    setattr(obj, attrs[-1], value)
+    return prev
+
+def set_attr_param(obj, attr, value):
+    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
 
 def copy_to_param(obj, attr, value):
     # inplace update tensor instead of replacing it
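For contrast with the copy_to_param helper that follows: set_attr_param rebinds the attribute to a fresh Parameter, so existing references keep seeing the old tensor, while copy_to_param writes through the storage of the Parameter that is already there. A small sketch of the difference (plain PyTorch, illustrative names):

import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
ref = layer.weight  # an outside reference to the original Parameter

# set_attr_param-style replacement rebinds the attribute;
# ref still points at the old tensor afterwards
layer.weight = nn.Parameter(torch.zeros(2, 2), requires_grad=False)
assert ref is not layer.weight

# copy_to_param-style in-place update writes through the existing
# Parameter's storage, so every holder of it sees the new values
ref = layer.weight
layer.weight.data.copy_(torch.ones(2, 2))
assert ref is layer.weight and torch.equal(ref, torch.ones(2, 2))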