Mirror of https://github.com/comfyanonymous/ComfyUI.git
utils.set_attr can now be used to set any attribute.
The old set_attr has been renamed to set_attr_param.
commit 1abf8374ec
parent dce3555339
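
In short: set_attr now assigns an arbitrary attribute along a dotted path and returns whatever it replaced, while the new set_attr_param keeps the previous behaviour of wrapping the value in a torch.nn.Parameter first. A minimal usage sketch (hypothetical, not from the repository; it assumes it runs inside a ComfyUI checkout where comfy.utils is importable and that the helpers behave exactly as in the hunks below):

    import torch
    import comfy.utils

    model = torch.nn.Linear(4, 4)

    # Replace a weight: set_attr_param wraps the tensor in a Parameter, as the old set_attr did.
    old_weight = comfy.utils.set_attr_param(model, "weight", torch.zeros(4, 4))

    # Replace an arbitrary attribute (no Parameter wrapping) and keep the previous value around.
    old_forward = comfy.utils.set_attr(model, "forward", lambda x: x)

Both calls return the previous value, which is what the ModelPatcher hunks below use for their backups.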
@@ -287,13 +287,13 @@ class ControlLora(ControlNet):
         for k in sd:
             weight = sd[k]
             try:
-                comfy.utils.set_attr(self.control_model, k, weight)
+                comfy.utils.set_attr_param(self.control_model, k, weight)
             except:
                 pass
 
         for k in self.control_weights:
             if k not in {"lora_controlnet"}:
-                comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
+                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
 
     def copy(self):
         c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@@ -176,10 +176,9 @@ class ModelPatcher:
 
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
-            old = comfy.utils.get_attr(self.model, k)
+            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
-            comfy.utils.set_attr(self.model, k, self.object_patches[k])
 
         if patch_weights:
             model_sd = self.model_state_dict()
@@ -203,7 +202,7 @@ class ModelPatcher:
                 if inplace_update:
                     comfy.utils.copy_to_param(self.model, key, out_weight)
                 else:
-                    comfy.utils.set_attr(self.model, key, out_weight)
+                    comfy.utils.set_attr_param(self.model, key, out_weight)
                 del temp_weight
 
         if device_to is not None:
@@ -342,7 +341,7 @@ class ModelPatcher:
                     comfy.utils.copy_to_param(self.model, k, self.backup[k])
             else:
                 for k in keys:
-                    comfy.utils.set_attr(self.model, k, self.backup[k])
+                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
 
             self.backup = {}
 
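
The ModelPatcher hunks above lean on the new return value: installing an object patch and capturing the value it displaces now happen in a single set_attr call, so the separate set_attr at the end of the loop goes away. A rough sketch of that bookkeeping pattern in isolation (TinyPatcher is a hypothetical, stripped-down stand-in for ModelPatcher, and comfy.utils must be importable):

    import comfy.utils

    class TinyPatcher:
        # Hypothetical illustration of the object-patch bookkeeping, not ComfyUI code.
        def __init__(self, model):
            self.model = model
            self.object_patches = {}
            self.object_patches_backup = {}

        def patch(self):
            for k in self.object_patches:
                # set_attr returns what was there before, so one call both installs
                # the patch and yields the value to back up.
                old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
                if k not in self.object_patches_backup:
                    self.object_patches_backup[k] = old

        def unpatch(self):
            for k, old in self.object_patches_backup.items():
                comfy.utils.set_attr(self.model, k, old)
            self.object_patches_backup = {}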
@@ -294,8 +294,11 @@ def set_attr(obj, attr, value):
     for name in attrs[:-1]:
         obj = getattr(obj, name)
     prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
-    del prev
+    setattr(obj, attrs[-1], value)
+    return prev
+
+def set_attr_param(obj, attr, value):
+    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
 
 def copy_to_param(obj, attr, value):
     # inplace update tensor instead of replacing it
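
Collected for reference, the helpers roughly as they read after this commit. set_attr and set_attr_param follow the hunk above; the leading attrs = attr.split(".") line and the body of copy_to_param fall outside the hunk, so those parts are assumptions about the surrounding code rather than quotes of it:

    import torch

    def set_attr(obj, attr, value):
        # walk a dotted path such as "diffusion_model.input_blocks.0.0.weight",
        # assign the final attribute, and hand back whatever it replaced
        attrs = attr.split(".")
        for name in attrs[:-1]:
            obj = getattr(obj, name)
        prev = getattr(obj, attrs[-1])
        setattr(obj, attrs[-1], value)
        return prev

    def set_attr_param(obj, attr, value):
        # old set_attr behaviour: wrap the tensor in a Parameter before assigning it
        return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))

    def copy_to_param(obj, attr, value):
        # inplace update tensor instead of replacing it
        # (body assumed; the diff only shows the signature and this comment)
        attrs = attr.split(".")
        for name in attrs[:-1]:
            obj = getattr(obj, name)
        prev = getattr(obj, attrs[-1])
        prev.data.copy_(value)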