Support preview images embedded in safetensors metadata

This commit is contained in:
catboxanon 2024-12-19 12:33:19 -05:00
parent c441048a4f
commit 9f6401d8dd

View File

@ -1,10 +1,13 @@
from __future__ import annotations from __future__ import annotations
import os import os
import base64
import json
import time import time
import logging import logging
import folder_paths import folder_paths
import glob import glob
import comfy.utils
from aiohttp import web from aiohttp import web
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
@ -59,13 +62,13 @@ class ModelFileManager:
folder = folders[0][path_index] folder = folders[0][path_index]
full_filename = os.path.join(folder, filename) full_filename = os.path.join(folder, filename)
preview_files = self.get_model_previews(full_filename) previews = self.get_model_previews(full_filename)
default_preview_file = preview_files[0] if len(preview_files) > 0 else None default_preview = previews[0] if len(previews) > 0 else None
if default_preview_file is None or not os.path.isfile(default_preview_file): if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404) return web.Response(status=404)
try: try:
with Image.open(default_preview_file) as img: with Image.open(default_preview) as img:
img_bytes = BytesIO() img_bytes = BytesIO()
img.save(img_bytes, format="WEBP") img.save(img_bytes, format="WEBP")
img_bytes.seek(0) img_bytes.seek(0)
@ -143,7 +146,7 @@ class ModelFileManager:
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str]: def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath) dirname = os.path.dirname(filepath)
if not os.path.exists(dirname): if not os.path.exists(dirname):
@ -152,8 +155,10 @@ class ModelFileManager:
basename = os.path.splitext(filepath)[0] basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False) match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image") image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str] = [] result: list[str | BytesIO] = []
for filename in image_files: for filename in image_files:
_basename = os.path.splitext(filename)[0] _basename = os.path.splitext(filename)[0]
@ -161,6 +166,18 @@ class ModelFileManager:
result.append(filename) result.append(filename)
if _basename == f"{basename}.preview": if _basename == f"{basename}.preview":
result.append(filename) result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result return result
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):