diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here new file mode 100644 index 00000000..e69de29b diff --git a/nodes.py b/nodes.py index e8b36c24..a2c7713a 100644 --- a/nodes.py +++ b/nodes.py @@ -29,6 +29,8 @@ import importlib import folder_paths +import safetensors.torch as sft + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -246,6 +248,91 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class SaveLatent: + def __init__(self): + self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") + self.type = "output" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(self.output_dir, subfolder) + + if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: + print("Saving latent outside the 'input/latents' folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + + # support save metadata for latent sharing + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"workflow": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + file = f"{filename}_{counter:05}_.latent" + file = os.path.join(full_output_folder, file) + + sft.save_file(samples, file, metadata=metadata) + + return {} + + +class LoadLatent: + input_dir = os.path.join(folder_paths.get_input_directory(), "latents") + + @classmethod + def INPUT_TYPES(s): + files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + return {"required": {"latent": [sorted(files), ]}, } + + CATEGORY = "_for_testing" + + RETURN_TYPES = ("LATENT", ) + FUNCTION = "load" + + def load(self, latent): + file = folder_paths.get_annotated_filepath(latent, self.input_dir) + + latent = sft.load_file(file, device="cpu") + + return (latent, ) + + class CheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -1235,6 +1322,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + + "LoadLatent": LoadLatent, + "SaveLatent": SaveLatent } NODE_DISPLAY_NAME_MAPPINGS = {