mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Fixed.
This commit is contained in:
parent
c2cd09540d
commit
3881f03545
@ -1,2 +1,2 @@
|
||||
# model_manager/__init__.py
|
||||
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType
|
||||
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress
|
||||
|
@ -1,7 +1,9 @@
|
||||
import aiohttp
|
||||
import os
|
||||
import traceback
|
||||
import logging
|
||||
from folder_paths import models_dir
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Callable, Any, Optional, Awaitable
|
||||
from enum import Enum
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@ -34,13 +36,13 @@ class DownloadModelResult():
|
||||
self.message = message
|
||||
self.already_existed = already_existed
|
||||
|
||||
async def download_model(session: aiohttp.ClientSession,
|
||||
async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_directory: str,
|
||||
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
|
||||
file_path, relative_path = create_model_path(model_name, model_directory)
|
||||
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_directory, models_dir)
|
||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||
if existing_file:
|
||||
return existing_file
|
||||
@ -49,8 +51,7 @@ async def download_model(session: aiohttp.ClientSession,
|
||||
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
response = await session.get(model_url)
|
||||
|
||||
response = await make_request(model_url)
|
||||
if response.status != 200:
|
||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||
@ -60,17 +61,22 @@ async def download_model(session: aiohttp.ClientSession,
|
||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in track_download_progress: {e}")
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
|
||||
def create_model_path(model_name: str, model_directory: str) -> tuple[str, str]:
|
||||
full_model_dir = os.path.join(models_dir, model_directory)
|
||||
|
||||
async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
|
||||
return await session.get(url)
|
||||
|
||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||
os.makedirs(full_model_dir, exist_ok=True)
|
||||
file_path = os.path.join(full_model_dir, model_name)
|
||||
relative_path = '/'.join([model_directory, model_name])
|
||||
return file_path, relative_path
|
||||
|
||||
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> Optional[DownloadModelResult]:
|
||||
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]:
|
||||
if os.path.exists(file_path):
|
||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
|
||||
await progress_callback(relative_path, status)
|
||||
@ -78,7 +84,8 @@ async def check_file_exists(file_path: str, model_name: str, progress_callback:
|
||||
return None
|
||||
|
||||
|
||||
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
|
||||
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
|
||||
try:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded = 0
|
||||
last_update_time = time.time()
|
||||
@ -91,21 +98,29 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
|
||||
last_update_time = time.time()
|
||||
|
||||
with open(file_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
chunk_iterator = response.content.iter_chunked(8192)
|
||||
while True:
|
||||
try:
|
||||
chunk = await chunk_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
|
||||
# Check if it's time to update progress
|
||||
if time.time() - last_update_time >= interval:
|
||||
await update_progress()
|
||||
|
||||
# Ensure we send a final update
|
||||
await update_progress()
|
||||
|
||||
logging.info(f"Download completed. Total downloaded: {downloaded}")
|
||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in track_download_progress: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult:
|
||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||
|
14
server.py
14
server.py
@ -28,7 +28,7 @@ import node_helpers
|
||||
from app.frontend_management import FrontendManager
|
||||
from app.user_manager import UserManager
|
||||
from model_filemanager import download_model, DownloadStatus
|
||||
|
||||
from typing import Optional
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
@ -76,7 +76,7 @@ class PromptServer():
|
||||
self.prompt_queue = None
|
||||
self.loop = loop
|
||||
self.messages = asyncio.Queue()
|
||||
self.client_session = None
|
||||
self.client_session:Optional[aiohttp.ClientSession] = None
|
||||
self.number = 0
|
||||
|
||||
middlewares = [cache_control]
|
||||
@ -579,7 +579,12 @@ class PromptServer():
|
||||
if not url or not model_directory or not model_filename:
|
||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||
|
||||
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress))
|
||||
session = self.client_session
|
||||
if session is None:
|
||||
logging.error("Client session is not initialized")
|
||||
return web.Response(status=500)
|
||||
|
||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress))
|
||||
await task
|
||||
|
||||
return web.Response(status=200)
|
||||
@ -726,6 +731,3 @@ class PromptServer():
|
||||
logging.warning(traceback.format_exc())
|
||||
|
||||
return json_data
|
||||
|
||||
def close_session(self):
|
||||
self.client_session.close()
|
||||
|
@ -1,68 +1,243 @@
|
||||
import pytest
|
||||
import aiohttp
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from model_filemanager import download_model, DownloadStatus, DownloadStatusType
|
||||
from aiohttp import ClientResponse
|
||||
import itertools
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
|
||||
|
||||
class AsyncIteratorMock:
|
||||
def __init__(self, seq):
|
||||
self.iter = iter(seq)
|
||||
|
||||
async def async_iterator(chunks):
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return next(self.iter)
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration
|
||||
|
||||
class ContentMock:
|
||||
def __init__(self, chunks):
|
||||
self.chunks = chunks
|
||||
|
||||
def iter_chunked(self, chunk_size):
|
||||
return AsyncIteratorMock(self.chunks)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_success():
|
||||
# Create a temporary directory for testing
|
||||
model_directory = str(uuid.uuid4())
|
||||
|
||||
# Create a mock session
|
||||
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||
|
||||
# Mock the response
|
||||
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||
# Mock dependencies
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 200
|
||||
mock_response.headers = {'Content-Length': '100'}
|
||||
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
|
||||
mock_response.headers = {'Content-Length': '1000'}
|
||||
|
||||
session.get.return_value.__aenter__.return_value = mock_response
|
||||
# Create a mock for content that returns an async iterator directly
|
||||
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
|
||||
mock_response.content = ContentMock(chunks)
|
||||
|
||||
# Create a mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
mock_make_request = AsyncMock(return_value=mock_response)
|
||||
mock_progress_callback = MagicMock()
|
||||
|
||||
# Mock file operations
|
||||
mock_open = MagicMock()
|
||||
mock_file = MagicMock()
|
||||
mock_open.return_value.__enter__.return_value = mock_file
|
||||
time_values = itertools.count(0, 0.1)
|
||||
|
||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \
|
||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('builtins.open', mock_open), \
|
||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||
|
||||
result = await download_model(
|
||||
mock_make_request,
|
||||
'model.bin',
|
||||
'http://example.com/model.bin',
|
||||
'checkpoints',
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, DownloadModelResult)
|
||||
assert result.message == 'Successfully downloaded model.bin'
|
||||
assert result.status == 'completed'
|
||||
assert result.already_existed is False
|
||||
|
||||
# Check progress callback calls
|
||||
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
||||
|
||||
# Check initial call
|
||||
mock_progress_callback.assert_any_call(
|
||||
'checkpoints/model.bin',
|
||||
DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin")
|
||||
)
|
||||
|
||||
# Check final call
|
||||
mock_progress_callback.assert_any_call(
|
||||
'checkpoints/model.bin',
|
||||
DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin")
|
||||
)
|
||||
|
||||
# Verify file writing
|
||||
mock_file.write.assert_any_call(b'a' * 500)
|
||||
mock_file.write.assert_any_call(b'b' * 300)
|
||||
mock_file.write.assert_any_call(b'c' * 200)
|
||||
|
||||
# Verify request was made
|
||||
mock_make_request.assert_called_once_with('http://example.com/model.bin')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_url_request_failure():
|
||||
# Mock dependencies
|
||||
mock_response = AsyncMock(spec=ClientResponse)
|
||||
mock_response.status = 404 # Simulate a "Not Found" error
|
||||
mock_get = AsyncMock(return_value=mock_response)
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
# Mock the create_model_path function
|
||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
||||
# Call the function
|
||||
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||
result = await download_model(
|
||||
mock_get,
|
||||
'model.safetensors',
|
||||
'http://example.com/model.safetensors',
|
||||
'mock_directory',
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert result['status'] == DownloadStatusType.COMPLETED
|
||||
assert result['message'] == 'Successfully downloaded model.safetensors'
|
||||
assert result['already_existed'] is False
|
||||
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
|
||||
assert isinstance(result, DownloadModelResult)
|
||||
assert result.status == 'error'
|
||||
assert result.message == 'Failed to download model.safetensors. Status code: 404'
|
||||
assert result.already_existed is False
|
||||
|
||||
# Check that progress_callback was called with the correct arguments
|
||||
mock_progress_callback.assert_any_call(
|
||||
'mock_directory/model.safetensors',
|
||||
DownloadStatus(
|
||||
status=DownloadStatusType.PENDING,
|
||||
progress_percentage=0,
|
||||
message='Starting download of model.safetensors'
|
||||
)
|
||||
)
|
||||
mock_progress_callback.assert_called_with(
|
||||
'mock_directory/model.safetensors',
|
||||
DownloadStatus(
|
||||
status=DownloadStatusType.ERROR,
|
||||
progress_percentage=0,
|
||||
message='Failed to download model.safetensors. Status code: 404'
|
||||
)
|
||||
)
|
||||
|
||||
# Verify that the get method was called with the correct URL
|
||||
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
||||
|
||||
# For create_model_path function
|
||||
def test_create_model_path(tmp_path, monkeypatch):
|
||||
mock_models_dir = tmp_path / "models"
|
||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
||||
|
||||
model_name = "test_model.bin"
|
||||
model_directory = "test_dir"
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
||||
|
||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
||||
assert relative_path == f"{model_directory}/{model_name}"
|
||||
assert os.path.exists(os.path.dirname(file_path))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_failure():
|
||||
# Create a temporary directory for testing
|
||||
model_directory = str(uuid.uuid4())
|
||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||
file_path = tmp_path / "existing_model.bin"
|
||||
file_path.touch() # Create an empty file
|
||||
|
||||
# Create a mock session
|
||||
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
# Mock the response with an error status code
|
||||
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 500
|
||||
session.get.return_value.__aenter__.return_value = mock_response
|
||||
result = await check_file_exists(str(file_path), "existing_model.bin", mock_callback, "test/existing_model.bin")
|
||||
|
||||
# Create a mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
assert result is not None
|
||||
assert result.status == "completed"
|
||||
assert result.message == "existing_model.bin already exists"
|
||||
assert result.already_existed is True
|
||||
|
||||
# Call the function
|
||||
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||
print(result)
|
||||
mock_callback.assert_called_once_with(
|
||||
"test/existing_model.bin",
|
||||
DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists")
|
||||
)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert result['status'] == DownloadStatusType.ERROR
|
||||
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500'
|
||||
assert result['already_existed'] is False
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||
file_path = tmp_path / "non_existing_model.bin"
|
||||
|
||||
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500'))
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
result = await check_file_exists(str(file_path), "non_existing_model.bin", mock_callback, "test/non_existing_model.bin")
|
||||
|
||||
assert result is None
|
||||
mock_callback.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_no_content_length():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.headers = {} # No Content-Length header
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
mock_open = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch('builtins.open', mock_open):
|
||||
result = await track_download_progress(
|
||||
mock_response, '/mock/path/model.bin', 'model.bin',
|
||||
mock_callback, 'models/model.bin', interval=0.1
|
||||
)
|
||||
|
||||
assert result.status == "completed"
|
||||
# Check that progress was reported even without knowing the total size
|
||||
mock_callback.assert_any_call(
|
||||
'models/model.bin',
|
||||
DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin")
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_interval():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.headers = {'Content-Length': '1000'}
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
mock_open = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create a mock time function that returns incremental float values
|
||||
mock_time = MagicMock()
|
||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||
|
||||
with patch('builtins.open', mock_open), \
|
||||
patch('time.time', mock_time):
|
||||
await track_download_progress(
|
||||
mock_response, '/mock/path/model.bin', 'model.bin',
|
||||
mock_callback, 'models/model.bin', interval=1.0
|
||||
)
|
||||
|
||||
# Print out the actual call count and the arguments of each call for debugging
|
||||
print(f"mock_callback was called {mock_callback.call_count} times")
|
||||
for i, call in enumerate(mock_callback.call_args_list):
|
||||
args, kwargs = call
|
||||
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
||||
|
||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||
|
||||
# Verify the first and last calls
|
||||
first_call = mock_callback.call_args_list[0]
|
||||
assert first_call[0][1].status == "in_progress"
|
||||
# Allow for some initial progress, but it should be less than 50%
|
||||
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
|
||||
|
||||
last_call = mock_callback.call_args_list[-1]
|
||||
assert last_call[0][1].status == "completed"
|
||||
assert last_call[0][1].progress_percentage == 100
|
@ -1,2 +1,3 @@
|
||||
pytest>=7.8.0
|
||||
pytest-aiohttp
|
||||
pytest-asyncio
|
Loading…
Reference in New Issue
Block a user