ComfyUI/comfy_extras/nodes_model_merging.py
comfyanonymous 9b93b920be Add CheckpointSave node to save checkpoints.
The created checkpoints contain workflow metadata that can be loaded by
dragging them on top of the UI or loading them with the "Load" button.

Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI
is using for inference on your hardware. To force fp32 use: --force-fp32

Anything that patches the model weights like merging or loras will be
saved.

The output directory is currently set to: output/checkpoints but that might
change in the future.
2023-06-26 12:22:27 -04:00

96 lines
3.3 KiB
Python

import comfy.sd
import comfy.utils
import folder_paths
import json
import os
class ModelMergeSimple:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
CATEGORY = "_for_testing/model_merging"
def merge(self, model1, model2, ratio):
m = model1.clone()
sd = model2.model_state_dict("diffusion_model.")
for k in sd:
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
return (m, )
class ModelMergeBlocks:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
CATEGORY = "_for_testing/model_merging"
def merge(self, model1, model2, **kwargs):
m = model1.clone()
sd = model2.model_state_dict("diffusion_model.")
default_ratio = next(iter(kwargs.values()))
for k in sd:
ratio = default_ratio
k_unet = k[len("diffusion_model."):]
for arg in kwargs:
if k_unet.startswith(arg):
ratio = kwargs[arg]
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
return (m, )
class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
"CheckpointSave": CheckpointSave,
}