Document get_attr and get_model_object

This commit is contained in:
huchenlei 2025-01-05 15:01:25 -05:00
parent 7da85fac3f
commit 3d0d386736
2 changed files with 35 additions and 2 deletions

View File

@ -402,7 +402,22 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def get_model_object(self, name):
def get_model_object(self, name: str) -> torch.nn.Module:
"""Retrieves a nested attribute from an object using dot notation considering
object patches.
Args:
obj: The object to get the attribute from
attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
model = MyModel()
weight = get_attr(model, "layer1.conv.weight")
# Equivalent to: model.layer1.conv.weight
"""
if name in self.object_patches:
return self.object_patches[name]
else:

View File

@ -693,7 +693,25 @@ def copy_to_param(obj, attr, value):
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
def get_attr(obj, attr: str):
"""Retrieves a nested attribute from an object using dot notation.
Args:
obj: The object to get the attribute from
attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
model = MyModel()
weight = get_attr(model, "layer1.conv.weight")
# Equivalent to: model.layer1.conv.weight
Important:
Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
accessing nested model objects under `ModelPatcher.model`.
"""
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)