Add CheckpointSave node to save checkpoints.

The created checkpoints contain workflow metadata that can be loaded by
dragging them on top of the UI or loading them with the "Load" button.

Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI
is using for inference on your hardware. To force fp32 use: --force-fp32

Anything that patches the model weights like merging or loras will be
saved.

The output directory is currently set to: output/checkpoints but that might
change in the future.
This commit is contained in:
comfyanonymous 2023-06-26 12:21:07 -04:00
parent b72a7a835a
commit 9b93b920be
12 changed files with 147 additions and 13 deletions

View File

@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
code2idx = {"q": 0, "k": 1, "v": 2} code2idx = {"q": 0, "k": 1, "v": 2}
def convert_text_enc_state_dict_v20(text_enc_dict): def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {} new_state_dict = {}
capture_qkv_weight = {} capture_qkv_weight = {}
capture_qkv_bias = {} capture_qkv_bias = {}
for k, v in text_enc_dict.items(): for k, v in text_enc_dict.items():
if not k.startswith(prefix):
continue
if ( if (
k.endswith(".self_attn.q_proj.weight") k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight") or k.endswith(".self_attn.k_proj.weight")

View File

@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np import numpy as np
from . import utils
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
def __init__(self, model_config, v_prediction=False): def __init__(self, model_config, v_prediction=False):
@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config unet_config = model_config.unet_config
self.latent_format = model_config.latent_format self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config) self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction self.v_prediction = v_prediction
@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent): def process_latent_out(self, latent):
return self.latent_format.process_out(latent) return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
class SD21UNCLIP(BaseModel): class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, v_prediction=True): def __init__(self, model_config, noise_aug_config, v_prediction=True):

View File

@ -545,11 +545,11 @@ class CLIP:
if self.layer_idx is not None: if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx) self.cond_stage_model.clip_layer(self.layer_idx)
try: try:
self.patcher.patch_model() self.patch_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
self.patcher.unpatch_model() self.unpatch_model()
except Exception as e: except Exception as e:
self.patcher.unpatch_model() self.unpatch_model()
raise e raise e
cond_out = cond cond_out = cond
@ -564,6 +564,15 @@ class CLIP:
def load_sd(self, sd): def load_sd(self, sd):
return self.cond_stage_model.load_sd(sd) return self.cond_stage_model.load_sd(sd)
def get_sd(self):
return self.cond_stage_model.state_dict()
def patch_model(self):
self.patcher.patch_model()
def unpatch_model(self):
self.patcher.unpatch_model()
class VAE: class VAE:
def __init__(self, ckpt_path=None, device=None, config=None): def __init__(self, ckpt_path=None, device=None, config=None):
if config is None: if config is None:
@ -665,6 +674,10 @@ class VAE:
self.first_stage_model = self.first_stage_model.cpu() self.first_stage_model = self.first_stage_model.cpu()
return samples return samples
def get_sd(self):
return self.first_stage_model.state_dict()
def broadcast_image_to(tensor, target_batch_size, batched_number): def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0] current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size) #print(current_batch_size, target_batch_size)
@ -1135,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over) print("left over keys:", left_over)
return (ModelPatcher(model), clip, vae, clipvision) return (ModelPatcher(model), clip, vae, clipvision)
def save_checkpoint(output_path, model, clip, vae, metadata=None):
try:
model.patch_model()
clip.patch_model()
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
utils.save_torch_file(sd, output_path, metadata=metadata)
model.unpatch_model()
clip.unpatch_model()
except Exception as e:
model.unpatch_model()
clip.unpatch_model()
raise e

View File

@ -9,6 +9,8 @@ from . import sdxl_clip
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
from . import diffusers_convert
class SD15(supported_models_base.BASE): class SD15(supported_models_base.BASE):
unet_config = { unet_config = {
"context_dim": 768, "context_dim": 768,
@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
def clip_target(self): def clip_target(self):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel) return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self): def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE):
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
for k in state_dict:
if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k]
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self): def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

View File

@ -64,3 +64,15 @@ class BASE:
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -2,10 +2,10 @@ import torch
import math import math
import struct import struct
import comfy.checkpoint_pickle import comfy.checkpoint_pickle
import safetensors.torch
def load_torch_file(ckpt, safe_load=False): def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"): if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu") sd = safetensors.torch.load_file(ckpt, device="cpu")
else: else:
if safe_load: if safe_load:
@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
sd = pl_sd sd = pl_sd
return sd return sd
def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
else:
safetensors.torch.save_file(sd, ckpt)
def transformers_convert(sd, prefix_from, prefix_to, number): def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = { keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight", "{}positional_embedding": "{}embeddings.position_embedding.weight",
@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd return sd
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
state_dict[k] = state_dict[k].to(dtype)
return state_dict
def safetensors_header(safetensors_path, max_size=100*1024*1024): def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f: with open(safetensors_path, "rb") as f:
header = f.read(8) header = f.read(8)

View File

@ -1,4 +1,8 @@
import comfy.sd
import comfy.utils
import folder_paths
import json
import os
class ModelMergeSimple: class ModelMergeSimple:
@classmethod @classmethod
@ -49,7 +53,43 @@ class ModelMergeBlocks:
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
return (m, ) return (m, )
class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple, "ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks "ModelMergeBlocks": ModelMergeBlocks,
"CheckpointSave": CheckpointSave,
} }

View File

@ -286,8 +286,7 @@ class SaveLatent:
output["latent_tensor"] = samples["samples"] output["latent_tensor"] = samples["samples"]
output["latent_format_version_0"] = torch.tensor([]) output["latent_format_version_0"] = torch.tensor([])
safetensors.torch.save_file(output, file, metadata=metadata) comfy.utils.save_torch_file(output, file, metadata=metadata)
return {} return {}

View File

@ -144,6 +144,7 @@
"\n", "\n",
"\n", "\n",
"# ESRGAN upscale model\n", "# ESRGAN upscale model\n",
"#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
"\n", "\n",

View File

@ -1468,7 +1468,7 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(reader.result)); this.loadGraphData(JSON.parse(reader.result));
}; };
reader.readAsText(file); reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) { } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file); const info = await getLatentMetadata(file);
if (info.workflow) { if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow)); this.loadGraphData(JSON.parse(info.workflow));

View File

@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
const dataView = new DataView(safetensorsData.buffer); const dataView = new DataView(safetensorsData.buffer);
let header_size = dataView.getUint32(0, true); let header_size = dataView.getUint32(0, true);
let offset = 8; let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__); r(header.__metadata__);
}; };
reader.readAsArrayBuffer(file); var slice = file.slice(0, 1024 * 1024 * 4);
reader.readAsArrayBuffer(slice);
}); });
} }

View File

@ -545,7 +545,7 @@ export class ComfyUI {
const fileInput = $el("input", { const fileInput = $el("input", {
id: "comfy-file-input", id: "comfy-file-input",
type: "file", type: "file",
accept: ".json,image/png,.latent", accept: ".json,image/png,.latent,.safetensors",
style: {display: "none"}, style: {display: "none"},
parent: document.body, parent: document.body,
onchange: () => { onchange: () => {