From 3d0d386736f15fce9b6afa615b3b845c7a8ec984 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 5 Jan 2025 15:01:25 -0500 Subject: [PATCH 1/4] Document get_attr and get_model_object --- comfy/model_patcher.py | 17 ++++++++++++++++- comfy/utils.py | 20 +++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4597ce11ccf..55de557ae25 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -402,7 +402,22 @@ def set_model_forward_timestep_embed_patch(self, patch): def add_object_patch(self, name, obj): self.object_patches[name] = obj - def get_model_object(self, name): + def get_model_object(self, name: str) -> torch.nn.Module: + """Retrieves a nested attribute from an object using dot notation considering + object patches. + + Args: + obj: The object to get the attribute from + attr (str): The attribute path using dot notation (e.g. "model.layer.weight") + + Returns: + The value of the requested attribute + + Example: + model = MyModel() + weight = get_attr(model, "layer1.conv.weight") + # Equivalent to: model.layer1.conv.weight + """ if name in self.object_patches: return self.object_patches[name] else: diff --git a/comfy/utils.py b/comfy/utils.py index ea666ae5b46..b486b2deb1b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -693,7 +693,25 @@ def copy_to_param(obj, attr, value): prev = getattr(obj, attrs[-1]) prev.data.copy_(value) -def get_attr(obj, attr): +def get_attr(obj, attr: str): + """Retrieves a nested attribute from an object using dot notation. + + Args: + obj: The object to get the attribute from + attr (str): The attribute path using dot notation (e.g. "model.layer.weight") + + Returns: + The value of the requested attribute + + Example: + model = MyModel() + weight = get_attr(model, "layer1.conv.weight") + # Equivalent to: model.layer1.conv.weight + + Important: + Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when + accessing nested model objects under `ModelPatcher.model`. + """ attrs = attr.split(".") for name in attrs: obj = getattr(obj, name) From d24e99bcbeb2dc461c3bdcf86cf4c745185bd2e0 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Mon, 6 Jan 2025 10:57:49 -0500 Subject: [PATCH 2/4] Update model_patcher.py --- comfy/model_patcher.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 55de557ae25..4176a806fcc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -407,8 +407,7 @@ def get_model_object(self, name: str) -> torch.nn.Module: object patches. Args: - obj: The object to get the attribute from - attr (str): The attribute path using dot notation (e.g. "model.layer.weight") + name (str): The attribute path using dot notation (e.g. "model.layer.weight") Returns: The value of the requested attribute From a60520127874c2797a1b4298a3b59d9c8a6558f0 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Mon, 6 Jan 2025 19:08:16 -0500 Subject: [PATCH 3/4] Update model_patcher.py --- comfy/model_patcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4176a806fcc..5c53e6c9a55 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -413,9 +413,9 @@ def get_model_object(self, name: str) -> torch.nn.Module: The value of the requested attribute Example: - model = MyModel() - weight = get_attr(model, "layer1.conv.weight") - # Equivalent to: model.layer1.conv.weight + patcher = ModelPatcher() + weight = patcher.get_model_object("layer1.conv.weight") + # Equivalent to: patcher.model.layer1.conv.weight """ if name in self.object_patches: return self.object_patches[name] From fb281b3c528529c8bcc5b0fdf8a143788b4105bf Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Mon, 6 Jan 2025 19:08:50 -0500 Subject: [PATCH 4/4] Update model_patcher.py --- comfy/model_patcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5c53e6c9a55..e886bdbb712 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -415,7 +415,6 @@ def get_model_object(self, name: str) -> torch.nn.Module: Example: patcher = ModelPatcher() weight = patcher.get_model_object("layer1.conv.weight") - # Equivalent to: patcher.model.layer1.conv.weight """ if name in self.object_patches: return self.object_patches[name]