Support setting weight offsets in weight patcher.

This commit is contained in:
comfyanonymous 2024-06-13 17:12:50 -04:00
parent 605e64f6d3
commit 37a08a41b3

View File

@ -209,11 +209,18 @@ class ModelPatcher:
p = set()
model_sd = self.model.state_dict()
for k in patches:
if k in model_sd:
offset = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if key in model_sd:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
current_patches = self.patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset))
self.patches[key] = current_patches
self.patches_uuid = uuid.uuid4()
return list(p)
@ -339,6 +346,12 @@ class ModelPatcher:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
@ -488,6 +501,9 @@ class ModelPatcher:
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
def unpatch_model(self, device_to=None, unpatch_weights=True):