From 37a08a41b3861e3b52dd69ff6b7fb00ba6b43758 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 13 Jun 2024 17:12:50 -0400 Subject: [PATCH] Support setting weight offsets in weight patcher. --- comfy/model_patcher.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 84592f93..2f80ae2b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -209,11 +209,18 @@ class ModelPatcher: p = set() model_sd = self.model.state_dict() for k in patches: - if k in model_sd: + offset = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + + if key in model_sd: p.add(k) - current_patches = self.patches.get(k, []) - current_patches.append((strength_patch, patches[k], strength_model)) - self.patches[k] = current_patches + current_patches = self.patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset)) + self.patches[key] = current_patches self.patches_uuid = uuid.uuid4() return list(p) @@ -339,6 +346,12 @@ class ModelPatcher: strength = p[0] v = p[1] strength_model = p[2] + offset = p[3] + + old_weight = None + if offset is not None: + old_weight = weight + weight = weight.narrow(offset[0], offset[1], offset[2]) if strength_model != 1.0: weight *= strength_model @@ -488,6 +501,9 @@ class ModelPatcher: else: logging.warning("patch type not recognized {} {}".format(patch_type, key)) + if old_weight is not None: + weight = old_weight + return weight def unpatch_model(self, device_to=None, unpatch_weights=True):