mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add unfinished ImageOnlyCheckpointSave node to save a SVD checkpoint.
This node is unfinished, SVD checkpoints saved with this node will work with ComfyUI but not with anything else.
This commit is contained in:
parent
fad02dc2df
commit
d76a04b6ea
@ -1,4 +1,4 @@
|
|||||||
from .utils import load_torch_file, transformers_convert, common_upscale
|
from .utils import load_torch_file, transformers_convert, common_upscale, state_dict_prefix_replace
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import contextlib
|
import contextlib
|
||||||
@ -41,9 +41,13 @@ class ClipVisionModel():
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
def get_sd(self):
|
||||||
|
return self.model.state_dict()
|
||||||
|
|
||||||
def encode_image(self, image):
|
def encode_image(self, image):
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
pixel_values = clip_preprocess(image.to(self.load_device)).float()
|
pixel_values = clip_preprocess(image.to(self.load_device)).float()
|
||||||
@ -76,6 +80,9 @@ def convert_to_transformers(sd, prefix):
|
|||||||
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
|
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
|
||||||
|
|
||||||
sd = transformers_convert(sd, prefix, "vision_model.", 48)
|
sd = transformers_convert(sd, prefix, "vision_model.", 48)
|
||||||
|
else:
|
||||||
|
replace_prefix = {prefix: ""}
|
||||||
|
sd = state_dict_prefix_replace(sd, replace_prefix)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||||
|
@ -179,19 +179,28 @@ class BaseModel(torch.nn.Module):
|
|||||||
def process_latent_out(self, latent):
|
def process_latent_out(self, latent):
|
||||||
return self.latent_format.process_out(latent)
|
return self.latent_format.process_out(latent)
|
||||||
|
|
||||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
extra_sds = []
|
||||||
|
if clip_state_dict is not None:
|
||||||
|
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
|
||||||
|
if vae_state_dict is not None:
|
||||||
|
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
|
||||||
|
if clip_vision_state_dict is not None:
|
||||||
|
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||||
|
|
||||||
unet_state_dict = self.diffusion_model.state_dict()
|
unet_state_dict = self.diffusion_model.state_dict()
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
|
||||||
if self.get_dtype() == torch.float16:
|
if self.get_dtype() == torch.float16:
|
||||||
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
|
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
|
||||||
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
|
|
||||||
|
|
||||||
if self.model_type == ModelType.V_PREDICTION:
|
if self.model_type == ModelType.V_PREDICTION:
|
||||||
unet_state_dict["v_pred"] = torch.tensor([])
|
unet_state_dict["v_pred"] = torch.tensor([])
|
||||||
|
|
||||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
for sd in extra_sds:
|
||||||
|
unet_state_dict.update(sd)
|
||||||
|
|
||||||
|
return unet_state_dict
|
||||||
|
|
||||||
def set_inpaint(self):
|
def set_inpaint(self):
|
||||||
self.inpaint_model = True
|
self.inpaint_model = True
|
||||||
|
13
comfy/sd.py
13
comfy/sd.py
@ -534,7 +534,14 @@ def load_unet(unet_path):
|
|||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
|
||||||
model_management.load_models_gpu([model, clip.load_model()])
|
clip_sd = None
|
||||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
load_models = [model]
|
||||||
|
if clip is not None:
|
||||||
|
load_models.append(clip.load_model())
|
||||||
|
clip_sd = clip.get_sd()
|
||||||
|
|
||||||
|
model_management.load_models_gpu(load_models)
|
||||||
|
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||||
|
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||||
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
|
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||||
|
@ -65,6 +65,12 @@ class BASE:
|
|||||||
replace_prefix = {"": "cond_stage_model."}
|
replace_prefix = {"": "cond_stage_model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
|
def process_clip_vision_state_dict_for_saving(self, state_dict):
|
||||||
|
replace_prefix = {}
|
||||||
|
if self.clip_vision_prefix is not None:
|
||||||
|
replace_prefix[""] = self.clip_vision_prefix
|
||||||
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def process_unet_state_dict_for_saving(self, state_dict):
|
def process_unet_state_dict_for_saving(self, state_dict):
|
||||||
replace_prefix = {"": "model.diffusion_model."}
|
replace_prefix = {"": "model.diffusion_model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
@ -119,25 +119,8 @@ class ModelMergeBlocks:
|
|||||||
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
class CheckpointSave:
|
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
|
||||||
def __init__(self):
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": { "model": ("MODEL",),
|
|
||||||
"clip": ("CLIP",),
|
|
||||||
"vae": ("VAE",),
|
|
||||||
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
|
||||||
|
|
||||||
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
|
||||||
prompt_info = ""
|
prompt_info = ""
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
prompt_info = json.dumps(prompt)
|
prompt_info = json.dumps(prompt)
|
||||||
@ -176,7 +159,27 @@ class CheckpointSave:
|
|||||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
|
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
|
||||||
|
|
||||||
|
class CheckpointSave:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"clip": ("CLIP",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
|
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||||
|
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
class CLIPSave:
|
class CLIPSave:
|
||||||
|
@ -3,6 +3,7 @@ import torch
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
import comfy_extras.nodes_model_merging
|
||||||
|
|
||||||
|
|
||||||
class ImageOnlyCheckpointLoader:
|
class ImageOnlyCheckpointLoader:
|
||||||
@ -78,10 +79,26 @@ class VideoLinearCFGGuidance:
|
|||||||
m.set_model_sampler_cfg_function(linear_cfg)
|
m.set_model_sampler_cfg_function(linear_cfg)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
|
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"clip_vision": ("CLIP_VISION",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||||
|
|
||||||
|
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||||
|
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||||
|
return {}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
||||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
||||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
||||||
|
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
Loading…
Reference in New Issue
Block a user