Add unfinished ImageOnlyCheckpointSave node to save a SVD checkpoint.

This node is unfinished, SVD checkpoints saved with this node will
work with ComfyUI but not with anything else.
This commit is contained in:
comfyanonymous 2024-01-17 19:37:19 -05:00
parent fad02dc2df
commit d76a04b6ea
6 changed files with 99 additions and 50 deletions

View File

@ -1,4 +1,4 @@
from .utils import load_torch_file, transformers_convert, common_upscale from .utils import load_torch_file, transformers_convert, common_upscale, state_dict_prefix_replace
import os import os
import torch import torch
import contextlib import contextlib
@ -41,9 +41,13 @@ class ClipVisionModel():
self.model.eval() self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd): def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False) return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image): def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device)).float() pixel_values = clip_preprocess(image.to(self.load_device)).float()
@ -76,6 +80,9 @@ def convert_to_transformers(sd, prefix):
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
sd = transformers_convert(sd, prefix, "vision_model.", 48) sd = transformers_convert(sd, prefix, "vision_model.", 48)
else:
replace_prefix = {prefix: ""}
sd = state_dict_prefix_replace(sd, replace_prefix)
return sd return sd
def load_clipvision_from_sd(sd, prefix="", convert_keys=False): def load_clipvision_from_sd(sd, prefix="", convert_keys=False):

View File

@ -179,19 +179,28 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent): def process_latent_out(self, latent):
return self.latent_format.process_out(latent) return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict): def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
if vae_state_dict is not None:
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16: if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
if self.model_type == ModelType.V_PREDICTION: if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([]) unet_state_dict["v_pred"] = torch.tensor([])
return {**unet_state_dict, **vae_state_dict, **clip_state_dict} for sd in extra_sds:
unet_state_dict.update(sd)
return unet_state_dict
def set_inpaint(self): def set_inpaint(self):
self.inpaint_model = True self.inpaint_model = True

View File

@ -534,7 +534,14 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model return model
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
model_management.load_models_gpu([model, clip.load_model()]) clip_sd = None
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
comfy.utils.save_torch_file(sd, output_path, metadata=metadata) comfy.utils.save_torch_file(sd, output_path, metadata=metadata)

View File

@ -65,6 +65,12 @@ class BASE:
replace_prefix = {"": "cond_stage_model."} replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
replace_prefix = {}
if self.clip_vision_prefix is not None:
replace_prefix[""] = self.clip_vision_prefix
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict): def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."} replace_prefix = {"": "model.diffusion_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -119,25 +119,8 @@ class ModelMergeBlocks:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, ) return (m, )
class CheckpointSave: def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
def __init__(self): full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = "" prompt_info = ""
if prompt is not None: if prompt is not None:
prompt_info = json.dumps(prompt) prompt_info = json.dumps(prompt)
@ -176,7 +159,27 @@ class CheckpointSave:
output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {} return {}
class CLIPSave: class CLIPSave:

View File

@ -3,6 +3,7 @@ import torch
import comfy.utils import comfy.utils
import comfy.sd import comfy.sd
import folder_paths import folder_paths
import comfy_extras.nodes_model_merging
class ImageOnlyCheckpointLoader: class ImageOnlyCheckpointLoader:
@ -78,10 +79,26 @@ class VideoLinearCFGGuidance:
m.set_model_sampler_cfg_function(linear_cfg) m.set_model_sampler_cfg_function(linear_cfg)
return (m, ) return (m, )
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance, "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {