From c36a55956422e9ed3de3735d1905c006d44e0ad1 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Wed, 7 Aug 2024 12:17:29 -0700 Subject: [PATCH] Use async mock. --- model_filemanager/download_models.py | 2 +- .../download_models_test.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 5173a2a6..ed0c5a52 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -112,7 +112,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s await update_progress() - logging.info(f"Download completed. Total downloaded: {downloaded}") + logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") await progress_callback(relative_path, status) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 9f0ac45b..26142289 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -7,38 +7,51 @@ from unittest.mock import AsyncMock, patch, MagicMock from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType class AsyncIteratorMock: + """ + A mock class that simulates an asynchronous iterator. + This is used to mimic the behavior of aiohttp's content iterator. + """ def __init__(self, seq): + # Convert the input sequence into an iterator self.iter = iter(seq) def __aiter__(self): + # This method is called when 'async for' is used return self async def __anext__(self): + # This method is called for each iteration in an 'async for' loop try: return next(self.iter) except StopIteration: + # This is the asynchronous equivalent of StopIteration raise StopAsyncIteration class ContentMock: + """ + A mock class that simulates the content attribute of an aiohttp ClientResponse. + This class provides the iter_chunked method which returns an async iterator of chunks. + """ def __init__(self, chunks): + # Store the chunks that will be returned by the iterator self.chunks = chunks def iter_chunked(self, chunk_size): + # This method mimics aiohttp's content.iter_chunked() + # For simplicity in testing, we ignore chunk_size and just return our predefined chunks return AsyncIteratorMock(self.chunks) @pytest.mark.asyncio async def test_download_model_success(): - # Mock dependencies mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.status = 200 mock_response.headers = {'Content-Length': '1000'} - # Create a mock for content that returns an async iterator directly chunks = [b'a' * 500, b'b' * 300, b'c' * 200] mock_response.content = ContentMock(chunks) mock_make_request = AsyncMock(return_value=mock_response) - mock_progress_callback = MagicMock() + mock_progress_callback = AsyncMock() # Mock file operations mock_open = MagicMock()