diff --git a/comfy/sd.py b/comfy/sd.py index e98a3aa87..24c07dbde 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -479,6 +479,15 @@ class VAE: self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() + self.png_chunks = {} + + if metadata is not None: + meta_color_space = metadata.get("modelspec.color_space") + if str(meta_color_space).lower().startswith("cicp:"): + cicp_chunk = meta_color_space.split("cicp:")[-1].split(",") + cicp_chunk = bytes([1 if b.lower() == 'true' else 0 if b.lower() == 'false' else int(b) for b in cicp_chunk]) + self.png_chunks[b"cICP"] = cicp_chunk + self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) diff --git a/nodes.py b/nodes.py index 95e831b8b..920a800b7 100644 --- a/nodes.py +++ b/nodes.py @@ -286,10 +286,12 @@ class VAEDecode: CATEGORY = "latent" DESCRIPTION = "Decodes latent images back into pixel space images." - def decode(self, vae, samples): + def decode(self, vae: comfy.sd.VAE, samples): images = vae.decode(samples["samples"]) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) + if vae.png_chunks is not None: + images.png_chunks = vae.png_chunks return (images, ) class VAEDecodeTiled: @@ -772,7 +774,8 @@ class VAELoader: else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) - vae = comfy.sd.VAE(sd=sd) + metadata = json.loads(comfy.utils.safetensors_header(vae_path, max_size=1024*1024) or "{}").get("__metadata__") + vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() return (vae,) @@ -1600,7 +1603,9 @@ class SaveImage: if extra_pnginfo is not None: for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) - + if hasattr(images, "png_chunks"): + for name, data in images.png_chunks.items(): + metadata.add(name, data) filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)