From 91ed2815d542c96fdad75edba2205140de3cbba6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Jul 2023 02:37:30 -0400 Subject: [PATCH] Add a node to merge CLIP models. --- comfy/sd.py | 7 +++++-- comfy_extras/nodes_model_merging.py | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 4bc9a15fa..bef4e8ef1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -479,8 +479,8 @@ class CLIP: def load_from_state_dict(self, sd): self.cond_stage_model.load_sd(sd) - def add_patches(self, patches, strength=1.0): - return self.patcher.add_patches(patches, strength) + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): + return self.patcher.add_patches(patches, strength_patch, strength_model) def clip_layer(self, layer_idx): self.layer_idx = layer_idx @@ -514,6 +514,9 @@ class CLIP: def unpatch_model(self): self.patcher.unpatch_model() + def get_key_patches(self): + return self.patcher.get_key_patches() + class VAE: def __init__(self, ckpt_path=None, device=None, config=None): if config is None: diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index eae9b6fdb..95c4cfece 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -23,6 +23,27 @@ class ModelMergeSimple: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +class CLIPMergeSimple: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2, ratio): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) + return (m, ) + class ModelMergeBlocks: @classmethod def INPUT_TYPES(s): @@ -94,4 +115,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks, "CheckpointSave": CheckpointSave, + "CLIPMergeSimple": CLIPMergeSimple, }