Use async mock.

This commit is contained in:
Robin Huang 2024-08-07 12:17:29 -07:00
parent 3881f03545
commit c36a559564
2 changed files with 17 additions and 4 deletions

View File

@ -112,7 +112,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
await update_progress() await update_progress()
logging.info(f"Download completed. Total downloaded: {downloaded}") logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)

View File

@ -7,38 +7,51 @@ from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
class AsyncIteratorMock: class AsyncIteratorMock:
"""
A mock class that simulates an asynchronous iterator.
This is used to mimic the behavior of aiohttp's content iterator.
"""
def __init__(self, seq): def __init__(self, seq):
# Convert the input sequence into an iterator
self.iter = iter(seq) self.iter = iter(seq)
def __aiter__(self): def __aiter__(self):
# This method is called when 'async for' is used
return self return self
async def __anext__(self): async def __anext__(self):
# This method is called for each iteration in an 'async for' loop
try: try:
return next(self.iter) return next(self.iter)
except StopIteration: except StopIteration:
# This is the asynchronous equivalent of StopIteration
raise StopAsyncIteration raise StopAsyncIteration
class ContentMock: class ContentMock:
"""
A mock class that simulates the content attribute of an aiohttp ClientResponse.
This class provides the iter_chunked method which returns an async iterator of chunks.
"""
def __init__(self, chunks): def __init__(self, chunks):
# Store the chunks that will be returned by the iterator
self.chunks = chunks self.chunks = chunks
def iter_chunked(self, chunk_size): def iter_chunked(self, chunk_size):
# This method mimics aiohttp's content.iter_chunked()
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
return AsyncIteratorMock(self.chunks) return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_success(): async def test_download_model_success():
# Mock dependencies
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200 mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'} mock_response.headers = {'Content-Length': '1000'}
# Create a mock for content that returns an async iterator directly # Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200] chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
mock_response.content = ContentMock(chunks) mock_response.content = ContentMock(chunks)
mock_make_request = AsyncMock(return_value=mock_response) mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = MagicMock() mock_progress_callback = AsyncMock()
# Mock file operations # Mock file operations
mock_open = MagicMock() mock_open = MagicMock()