From 7461e8eb0073add315c65c6f5e361f0891bffc7d Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Thu, 8 Aug 2024 13:06:26 -0700 Subject: [PATCH] Use pydantic. --- model_filemanager/download_models.py | 73 +++++++++++-------- requirements.txt | 1 + .../download_models_test.py | 17 ++++- 3 files changed, 57 insertions(+), 34 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 50302619..0a59f001 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -8,35 +8,25 @@ import re from typing import Callable, Any, Optional, Awaitable, Dict from enum import Enum import time -from dataclasses import dataclass +from pydantic import BaseModel, Field - -class DownloadStatusType(Enum): +class DownloadStatusType(str, Enum): PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" ERROR = "error" -@dataclass -class DownloadModelStatus(): - status: str - progress_percentage: float +class DownloadModelStatus(BaseModel): + status: DownloadStatusType + progress_percentage: float = Field(ge=0, le=100) message: str already_existed: bool = False - def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool): - self.status = status.value # Store the string value of the Enum - self.progress_percentage = progress_percentage - self.message = message - self.already_existed = already_existed - + class Config: + use_enum_values = True + def to_dict(self) -> Dict[str, Any]: - return { - "status": self.status, - "progress_percentage": self.progress_percentage, - "message": self.message, - "already_existed": self.already_existed - } + return self.model_dump() async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], model_name: str, @@ -65,10 +55,10 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht """ if not validate_model_subdirectory(model_sub_directory): return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - "Invalid model subdirectory", - False + status=DownloadStatusType.ERROR, + progress_percentage=0, + message="Invalid model subdirectory", + already_existed=False ) file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) @@ -77,16 +67,25 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht return existing_file try: - status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) + status = DownloadModelStatus(status=DownloadStatusType.PENDING, + progress_percentage=0, + message=f"Starting download of {model_name}", + already_existed=False) await progress_callback(relative_path, status) response = await model_download_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" logging.error(error_message) - status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) + status = DownloadModelStatus(status=DownloadStatusType.ERROR, + progress_percentage= 0, + message=error_message, + already_existed= False) await progress_callback(relative_path, status) - return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) + return DownloadModelStatus(status=DownloadStatusType.ERROR, + progress_percentage=0, + message= error_message, + already_existed=False) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) @@ -107,7 +106,11 @@ async def check_file_exists(file_path: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): - status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) + status = DownloadModelStatus( + status=DownloadStatusType.COMPLETED, + progress_percentage=100, + message= f"{model_name} already exists", + already_existed=True) await progress_callback(relative_path, status) return status return None @@ -127,7 +130,10 @@ async def track_download_progress(response: aiohttp.ClientResponse, async def update_progress(): nonlocal last_update_time progress = (downloaded / total_size) * 100 if total_size > 0 else 0 - status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) + status = DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS, + progress_percentage=progress, + message=f"Downloading {model_name}", + already_existed=False) await progress_callback(relative_path, status) last_update_time = time.time() @@ -147,7 +153,11 @@ async def track_download_progress(response: aiohttp.ClientResponse, await update_progress() logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") - status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) + status = DownloadModelStatus( + status=DownloadStatusType.COMPLETED, + progress_percentage=100, + message=f"Successfully downloaded {model_name}", + already_existed=False) await progress_callback(relative_path, status) return status @@ -161,7 +171,10 @@ async def handle_download_error(e: Exception, progress_callback: Callable[[str, DownloadModelStatus], Any], relative_path: str) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" - status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) + status = DownloadModelStatus(status=DownloadStatusType.ERROR, + progress_percentage=0, + message=error_message, + already_existed=False) await progress_callback(relative_path, status) return status diff --git a/requirements.txt b/requirements.txt index 4c2c0b2b..ce9db8c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ Pillow scipy tqdm psutil +pydantic~=2.8 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 26dd94d4..f90c09a1 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -84,13 +84,19 @@ async def test_download_model_success(): # Check initial call mock_progress_callback.assert_any_call( 'checkpoints/model.bin', - DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False) + DownloadModelStatus(status=DownloadStatusType.PENDING, + progress_percentage= 0, + message="Starting download of model.bin", + already_existed= False) ) # Check final call mock_progress_callback.assert_any_call( 'checkpoints/model.bin', - DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False) + DownloadModelStatus(status=DownloadStatusType.COMPLETED, + progress_percentage=100, + message="Successfully downloaded model.bin", + already_existed= False) ) # Verify file writing @@ -204,7 +210,10 @@ async def test_check_file_exists_when_file_exists(tmp_path): mock_callback.assert_called_once_with( "test/existing_model.bin", - DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True) + DownloadModelStatus(status=DownloadStatusType.COMPLETED, + progress_percentage=100, + message="existing_model.bin already exists", + already_existed=True) ) @pytest.mark.asyncio @@ -237,7 +246,7 @@ async def test_track_download_progress_no_content_length(): # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( 'models/model.bin', - DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False) + DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS, progress_percentage= 0, message="Downloading model.bin", already_existed=False) ) @pytest.mark.asyncio