Use async mock.

This commit is contained in:
Robin Huang 2024-08-07 12:17:29 -07:00
parent 3881f03545
commit c36a559564
2 changed files with 17 additions and 4 deletions

View File

@ -112,7 +112,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
await update_progress()
logging.info(f"Download completed. Total downloaded: {downloaded}")
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status)

View File

@ -7,38 +7,51 @@ from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
class AsyncIteratorMock:
"""
A mock class that simulates an asynchronous iterator.
This is used to mimic the behavior of aiohttp's content iterator.
"""
def __init__(self, seq):
# Convert the input sequence into an iterator
self.iter = iter(seq)
def __aiter__(self):
# This method is called when 'async for' is used
return self
async def __anext__(self):
# This method is called for each iteration in an 'async for' loop
try:
return next(self.iter)
except StopIteration:
# This is the asynchronous equivalent of StopIteration
raise StopAsyncIteration
class ContentMock:
"""
A mock class that simulates the content attribute of an aiohttp ClientResponse.
This class provides the iter_chunked method which returns an async iterator of chunks.
"""
def __init__(self, chunks):
# Store the chunks that will be returned by the iterator
self.chunks = chunks
def iter_chunked(self, chunk_size):
# This method mimics aiohttp's content.iter_chunked()
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio
async def test_download_model_success():
# Mock dependencies
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'}
# Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
mock_response.content = ContentMock(chunks)
mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = MagicMock()
mock_progress_callback = AsyncMock()
# Mock file operations
mock_open = MagicMock()