Merge 1ffed0f41cd733237f409afaba0a9ca40c8b50b5 into aee2908d0395577a6e2e13d1307aaf271424108b

This commit is contained in:
catboxanon 2025-05-17 23:11:03 -04:00 committed by GitHub
commit 0609bc160b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 3 deletions

View File

@ -479,6 +479,15 @@ class VAE:
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.png_chunks = {}
if metadata is not None:
meta_color_space = metadata.get("modelspec.color_space")
if str(meta_color_space).lower().startswith("cicp:"):
cicp_chunk = meta_color_space.split("cicp:")[-1].split(",")
cicp_chunk = bytes([1 if b.lower() == 'true' else 0 if b.lower() == 'false' else int(b) for b in cicp_chunk])
self.png_chunks[b"cICP"] = cicp_chunk
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

View File

@ -286,10 +286,12 @@ class VAEDecode:
CATEGORY = "latent"
DESCRIPTION = "Decodes latent images back into pixel space images."
def decode(self, vae, samples):
def decode(self, vae: comfy.sd.VAE, samples):
images = vae.decode(samples["samples"])
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
if vae.png_chunks is not None:
images.png_chunks = vae.png_chunks
return (images, )
class VAEDecodeTiled:
@ -772,7 +774,8 @@ class VAELoader:
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
metadata = json.loads(comfy.utils.safetensors_header(vae_path, max_size=1024*1024) or "{}").get("__metadata__")
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
return (vae,)
@ -1600,7 +1603,9 @@ class SaveImage:
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
if hasattr(images, "png_chunks"):
for name, data in images.png_chunks.items():
metadata.add(name, data)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)