mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Add unit test for safetensors embedded image previews
This commit is contained in:
parent
9f6401d8dd
commit
3b0a3cb6f9
62
tests-unit/app_test/model_manager_test.py
Normal file
62
tests-unit/app_test/model_manager_test.py
Normal file
@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
import base64
|
||||
import json
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(model_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
model_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
||||
img = Image.new('RGB', (100, 100), 'white')
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr.seek(0)
|
||||
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
||||
|
||||
safetensors_file = tmp_path / "test_model.safetensors"
|
||||
header_bytes = json.dumps({
|
||||
"__metadata__": {
|
||||
"ssmd_cover_images": json.dumps([img_b64])
|
||||
}
|
||||
}).encode('utf-8')
|
||||
length_bytes = struct.pack('<Q', len(header_bytes))
|
||||
with open(safetensors_file, 'wb') as f:
|
||||
f.write(length_bytes)
|
||||
f.write(header_bytes)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
|
||||
|
||||
# Verify response
|
||||
assert response.status == 200
|
||||
assert response.content_type == 'image/webp'
|
||||
|
||||
# Verify the response contains valid image data
|
||||
img_bytes = BytesIO(await response.read())
|
||||
img = Image.open(img_bytes)
|
||||
assert img.format
|
||||
assert img.format.lower() == 'webp'
|
||||
|
||||
# Clean up
|
||||
img.close()
|
Loading…
Reference in New Issue
Block a user