This commit is contained in:
Robin Huang 2024-08-07 12:12:01 -07:00
parent c2cd09540d
commit 3881f03545
5 changed files with 288 additions and 95 deletions

View File

@ -1,2 +1,2 @@
# model_manager/__init__.py # model_manager/__init__.py
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress

View File

@ -1,7 +1,9 @@
import aiohttp import aiohttp
import os import os
import traceback
import logging
from folder_paths import models_dir from folder_paths import models_dir
from typing import Callable, Any, Optional from typing import Callable, Any, Optional, Awaitable
from enum import Enum from enum import Enum
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@ -34,13 +36,13 @@ class DownloadModelResult():
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
async def download_model(session: aiohttp.ClientSession, async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_directory: str, model_directory: str,
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult: progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
file_path, relative_path = create_model_path(model_name, model_directory)
file_path, relative_path = create_model_path(model_name, model_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file: if existing_file:
return existing_file return existing_file
@ -49,8 +51,7 @@ async def download_model(session: aiohttp.ClientSession,
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
response = await session.get(model_url) response = await make_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
@ -60,17 +61,22 @@ async def download_model(session: aiohttp.ClientSession,
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
except Exception as e: except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback, relative_path)
def create_model_path(model_name: str, model_directory: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_dir, model_directory) async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
return await session.get(url)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True) os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name) file_path = os.path.join(full_model_dir, model_name)
relative_path = '/'.join([model_directory, model_name]) relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path return file_path, relative_path
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> Optional[DownloadModelResult]: async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists") status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
@ -78,34 +84,43 @@ async def check_file_exists(file_path: str, model_name: str, progress_callback:
return None return None
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str, interval: float = 1.0) -> DownloadModelResult: async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
total_size = int(response.headers.get('Content-Length', 0)) try:
downloaded = 0 total_size = int(response.headers.get('Content-Length', 0))
last_update_time = time.time() downloaded = 0
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
last_update_time = time.time() last_update_time = time.time()
with open(file_path, 'wb') as f: async def update_progress():
async for chunk in response.content.iter_chunked(8192): nonlocal last_update_time
f.write(chunk) progress = (downloaded / total_size) * 100 if total_size > 0 else 0
downloaded += len(chunk) status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
last_update_time = time.time()
# Check if it's time to update progress with open(file_path, 'wb') as f:
if time.time() - last_update_time >= interval: chunk_iterator = response.content.iter_chunked(8192)
await update_progress() while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
# Ensure we send a final update if time.time() - last_update_time >= interval:
await update_progress() await update_progress()
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") await update_progress()
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False) logging.info(f"Download completed. Total downloaded: {downloaded}")
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False)
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult: async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"

View File

@ -28,7 +28,7 @@ import node_helpers
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from model_filemanager import download_model, DownloadStatus from model_filemanager import download_model, DownloadStatus
from typing import Optional
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
@ -76,7 +76,7 @@ class PromptServer():
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.client_session = None self.client_session:Optional[aiohttp.ClientSession] = None
self.number = 0 self.number = 0
middlewares = [cache_control] middlewares = [cache_control]
@ -579,7 +579,12 @@ class PromptServer():
if not url or not model_directory or not model_filename: if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress)) session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress))
await task await task
return web.Response(status=200) return web.Response(status=200)
@ -726,6 +731,3 @@ class PromptServer():
logging.warning(traceback.format_exc()) logging.warning(traceback.format_exc())
return json_data return json_data
def close_session(self):
self.client_session.close()

View File

@ -1,68 +1,243 @@
import pytest import pytest
import aiohttp import aiohttp
import uuid from aiohttp import ClientResponse
from unittest.mock import AsyncMock, MagicMock import itertools
from model_filemanager import download_model, DownloadStatus, DownloadStatusType import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
class AsyncIteratorMock:
def __init__(self, seq):
self.iter = iter(seq)
async def async_iterator(chunks): def __aiter__(self):
for chunk in chunks: return self
yield chunk
async def __anext__(self):
try:
return next(self.iter)
except StopIteration:
raise StopAsyncIteration
class ContentMock:
def __init__(self, chunks):
self.chunks = chunks
def iter_chunked(self, chunk_size):
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_success(): async def test_download_model_success():
# Create a temporary directory for testing # Mock dependencies
model_directory = str(uuid.uuid4()) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
# Create a mock session
session = AsyncMock(spec=aiohttp.ClientSession)
# Mock the response
mock_response = MagicMock(spec=aiohttp.ClientResponse)
mock_response.status = 200 mock_response.status = 200
mock_response.headers = {'Content-Length': '100'} mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
session.get.return_value.__aenter__.return_value = mock_response # Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
mock_response.content = ContentMock(chunks)
# Create a mock progress callback mock_make_request = AsyncMock(return_value=mock_response)
progress_callback = AsyncMock() mock_progress_callback = MagicMock()
# Call the function # Mock file operations
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback) mock_open = MagicMock()
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1)
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
mock_make_request,
'model.bin',
'http://example.com/model.bin',
'checkpoints',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelResult)
assert result.message == 'Successfully downloaded model.bin'
assert result.status == 'completed'
assert result.already_existed is False
# Check progress callback calls
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
# Check initial call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin")
)
# Check final call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin")
)
# Verify file writing
mock_file.write.assert_any_call(b'a' * 500)
mock_file.write.assert_any_call(b'b' * 300)
mock_file.write.assert_any_call(b'c' * 200)
# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.bin')
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
# Mock dependencies
mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
# Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
# Mock the check_file_exists function to return None (file doesn't exist)
with patch('model_filemanager.check_file_exists', return_value=None):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'mock_directory',
mock_progress_callback
)
# Assert the expected behavior # Assert the expected behavior
assert result['status'] == DownloadStatusType.COMPLETED assert isinstance(result, DownloadModelResult)
assert result['message'] == 'Successfully downloaded model.safetensors' assert result.status == 'error'
assert result['already_existed'] is False assert result.message == 'Failed to download model.safetensors. Status code: 404'
relative_path = '/'.join([model_directory, 'model.safetensors']) assert result.already_existed is False
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
# Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors',
DownloadStatus(
status=DownloadStatusType.PENDING,
progress_percentage=0,
message='Starting download of model.safetensors'
)
)
mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors',
DownloadStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
message='Failed to download model.safetensors. Status code: 404'
)
)
# Verify that the get method was called with the correct URL
mock_get.assert_called_once_with('http://example.com/model.safetensors')
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
model_name = "test_model.bin"
model_directory = "test_dir"
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}"
assert os.path.exists(os.path.dirname(file_path))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_failure(): async def test_check_file_exists_when_file_exists(tmp_path):
# Create a temporary directory for testing file_path = tmp_path / "existing_model.bin"
model_directory = str(uuid.uuid4()) file_path.touch() # Create an empty file
# Create a mock session mock_callback = AsyncMock()
session = AsyncMock(spec=aiohttp.ClientSession)
# Mock the response with an error status code result = await check_file_exists(str(file_path), "existing_model.bin", mock_callback, "test/existing_model.bin")
mock_response = MagicMock(spec=aiohttp.ClientResponse)
mock_response.status = 500
session.get.return_value.__aenter__.return_value = mock_response
# Create a mock progress callback assert result is not None
progress_callback = AsyncMock() assert result.status == "completed"
assert result.message == "existing_model.bin already exists"
assert result.already_existed is True
# Call the function mock_callback.assert_called_once_with(
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback) "test/existing_model.bin",
print(result) DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists")
)
# Assert the expected behavior @pytest.mark.asyncio
assert result['status'] == DownloadStatusType.ERROR async def test_check_file_exists_when_file_does_not_exist(tmp_path):
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500' file_path = tmp_path / "non_existing_model.bin"
assert result['already_existed'] is False
relative_path = '/'.join([model_directory, 'model.safetensors']) mock_callback = AsyncMock()
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500'))
result = await check_file_exists(str(file_path), "non_existing_model.bin", mock_callback, "test/non_existing_model.bin")
assert result is None
mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
with patch('builtins.open', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.bin', 'model.bin',
mock_callback, 'models/model.bin', interval=0.1
)
assert result.status == "completed"
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'models/model.bin',
DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin")
)
@pytest.mark.asyncio
async def test_track_download_progress_interval():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
# Create a mock time function that returns incremental float values
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.bin', 'model.bin',
mock_callback, 'models/model.bin', interval=1.0
)
# Print out the actual call count and the arguments of each call for debugging
print(f"mock_callback was called {mock_callback.call_count} times")
for i, call in enumerate(mock_callback.call_args_list):
args, kwargs = call
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
# Verify the first and last calls
first_call = mock_callback.call_args_list[0]
assert first_call[0][1].status == "in_progress"
# Allow for some initial progress, but it should be less than 50%
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
last_call = mock_callback.call_args_list[-1]
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100

View File

@ -1,2 +1,3 @@
pytest>=7.8.0 pytest>=7.8.0
pytest-aiohttp pytest-aiohttp
pytest-asyncio