mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add model downloading endpoint.
This commit is contained in:
parent
b334605a66
commit
6976ccc5ca
2
model_filemanager/__init__.py
Normal file
2
model_filemanager/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# model_manager/__init__.py
|
||||
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType
|
78
model_filemanager/download_models.py
Normal file
78
model_filemanager/download_models.py
Normal file
@ -0,0 +1,78 @@
|
||||
import aiohttp
|
||||
import os
|
||||
from folder_paths import models_dir
|
||||
from typing import Callable, Any
|
||||
from enum import Enum
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
class DownloadStatusType(Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
@dataclass
|
||||
class DownloadStatus():
|
||||
status: DownloadStatusType
|
||||
progress_percentage: float
|
||||
message: str
|
||||
|
||||
@dataclass
|
||||
class DownloadModelResult():
|
||||
status: DownloadStatusType
|
||||
message: str
|
||||
already_existed: bool
|
||||
|
||||
async def download_model(session: aiohttp.ClientSession,
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_directory: str,
|
||||
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
|
||||
"""
|
||||
Asynchronously downloads a model file from a given URL to a specified directory.
|
||||
|
||||
If the file already exists, return success.
|
||||
Downloads the file in chunks and reports progress as a percentage through the callback function.
|
||||
"""
|
||||
|
||||
full_model_dir = os.path.join(models_dir, model_directory)
|
||||
os.makedirs(full_model_dir, exist_ok=True) # Ensure the directory exists.
|
||||
file_path = os.path.join(full_model_dir, model_name)
|
||||
relative_path = '/'.join([model_directory, model_name])
|
||||
if os.path.exists(file_path):
|
||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
|
||||
await progress_callback(relative_path, status)
|
||||
return {"status": DownloadStatusType.COMPLETED, "message": f"{model_name} already exists", "already_existed": True}
|
||||
try:
|
||||
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
async with session.get(model_url) as response:
|
||||
if response.status != 200:
|
||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||
await progress_callback(relative_path, status)
|
||||
return {"status": DownloadStatusType.ERROR, "message": f"Failed to download {model_name}. Status code: {response.status} ", "already_existed": False}
|
||||
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded = 0
|
||||
|
||||
with open(file_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
return {"status": DownloadStatusType.COMPLETED, "message": f"Successfully downloaded {model_name}", "already_existed": False}
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||
await progress_callback(relative_path, status)
|
||||
return {"status": DownloadStatusType.ERROR, "message": error_message, "already_existed": False}
|
29
server.py
29
server.py
@ -12,7 +12,6 @@ import json
|
||||
import glob
|
||||
import struct
|
||||
import ssl
|
||||
import hashlib
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from io import BytesIO
|
||||
@ -28,6 +27,7 @@ import comfy.model_management
|
||||
import node_helpers
|
||||
from app.frontend_management import FrontendManager
|
||||
from app.user_manager import UserManager
|
||||
from model_filemanager import download_model, DownloadStatus
|
||||
|
||||
|
||||
class BinaryEventTypes:
|
||||
@ -76,6 +76,8 @@ class PromptServer():
|
||||
self.prompt_queue = None
|
||||
self.loop = loop
|
||||
self.messages = asyncio.Queue()
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
self.number = 0
|
||||
|
||||
middlewares = [cache_control]
|
||||
@ -559,6 +561,28 @@ class PromptServer():
|
||||
self.prompt_queue.delete_history_item(id_to_delete)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/download")
|
||||
async def download_handler(request):
|
||||
async def report_progress(filename: str, status: DownloadStatus):
|
||||
await self.send_json(filename, {
|
||||
"progress_percentage": status.progress_percentage,
|
||||
"status": status.status,
|
||||
"message": status.message
|
||||
})
|
||||
|
||||
data = await request.json()
|
||||
url = data.get('url')
|
||||
model_directory = data.get('model_directory')
|
||||
model_filename = data.get('model_filename')
|
||||
|
||||
if not url or not model_directory or not model_filename:
|
||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||
|
||||
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress))
|
||||
await task
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
@ -698,3 +722,6 @@ class PromptServer():
|
||||
logging.warning(traceback.format_exc())
|
||||
|
||||
return json_data
|
||||
|
||||
def close_session(self):
|
||||
self.client_session.close()
|
||||
|
0
tests-unit/prompt_server_test/__init__.py
Normal file
0
tests-unit/prompt_server_test/__init__.py
Normal file
68
tests-unit/prompt_server_test/download_models_test.py
Normal file
68
tests-unit/prompt_server_test/download_models_test.py
Normal file
@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import aiohttp
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from model_filemanager import download_model, DownloadStatus, DownloadStatusType
|
||||
|
||||
|
||||
async def async_iterator(chunks):
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_success():
|
||||
# Create a temporary directory for testing
|
||||
model_directory = str(uuid.uuid4())
|
||||
|
||||
# Create a mock session
|
||||
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||
|
||||
# Mock the response
|
||||
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 200
|
||||
mock_response.headers = {'Content-Length': '100'}
|
||||
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
|
||||
|
||||
session.get.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
# Create a mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
|
||||
# Call the function
|
||||
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert result['status'] == DownloadStatusType.COMPLETED
|
||||
assert result['message'] == 'Successfully downloaded model.safetensors'
|
||||
assert result['already_existed'] is False
|
||||
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_failure():
|
||||
# Create a temporary directory for testing
|
||||
model_directory = str(uuid.uuid4())
|
||||
|
||||
# Create a mock session
|
||||
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||
|
||||
# Mock the response with an error status code
|
||||
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 500
|
||||
session.get.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
# Create a mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
|
||||
# Call the function
|
||||
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||
print(result)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert result['status'] == DownloadStatusType.ERROR
|
||||
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500'
|
||||
assert result['already_existed'] is False
|
||||
|
||||
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500'))
|
@ -1 +1,2 @@
|
||||
pytest>=7.8.0
|
||||
pytest-aiohttp
|
Loading…
Reference in New Issue
Block a user