mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Use pydantic.
This commit is contained in:
parent
db1ce51fdf
commit
7461e8eb00
@ -8,35 +8,25 @@ import re
|
|||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class DownloadStatusType(str, Enum):
|
||||||
class DownloadStatusType(Enum):
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
IN_PROGRESS = "in_progress"
|
IN_PROGRESS = "in_progress"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
@dataclass
|
class DownloadModelStatus(BaseModel):
|
||||||
class DownloadModelStatus():
|
status: DownloadStatusType
|
||||||
status: str
|
progress_percentage: float = Field(ge=0, le=100)
|
||||||
progress_percentage: float
|
|
||||||
message: str
|
message: str
|
||||||
already_existed: bool = False
|
already_existed: bool = False
|
||||||
|
|
||||||
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
|
class Config:
|
||||||
self.status = status.value # Store the string value of the Enum
|
use_enum_values = True
|
||||||
self.progress_percentage = progress_percentage
|
|
||||||
self.message = message
|
|
||||||
self.already_existed = already_existed
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return {
|
return self.model_dump()
|
||||||
"status": self.status,
|
|
||||||
"progress_percentage": self.progress_percentage,
|
|
||||||
"message": self.message,
|
|
||||||
"already_existed": self.already_existed
|
|
||||||
}
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -65,10 +55,10 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
if not validate_model_subdirectory(model_sub_directory):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
status=DownloadStatusType.ERROR,
|
||||||
0,
|
progress_percentage=0,
|
||||||
"Invalid model subdirectory",
|
message="Invalid model subdirectory",
|
||||||
False
|
already_existed=False
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
||||||
@ -77,16 +67,25 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
status = DownloadModelStatus(status=DownloadStatusType.PENDING,
|
||||||
|
progress_percentage=0,
|
||||||
|
message=f"Starting download of {model_name}",
|
||||||
|
already_existed=False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(status=DownloadStatusType.ERROR,
|
||||||
|
progress_percentage= 0,
|
||||||
|
message=error_message,
|
||||||
|
already_existed= False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(status=DownloadStatusType.ERROR,
|
||||||
|
progress_percentage=0,
|
||||||
|
message= error_message,
|
||||||
|
already_existed=False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
||||||
|
|
||||||
@ -107,7 +106,11 @@ async def check_file_exists(file_path: str,
|
|||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
status = DownloadModelStatus(
|
||||||
|
status=DownloadStatusType.COMPLETED,
|
||||||
|
progress_percentage=100,
|
||||||
|
message= f"{model_name} already exists",
|
||||||
|
already_existed=True)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return status
|
return status
|
||||||
return None
|
return None
|
||||||
@ -127,7 +130,10 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
async def update_progress():
|
async def update_progress():
|
||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
status = DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS,
|
||||||
|
progress_percentage=progress,
|
||||||
|
message=f"Downloading {model_name}",
|
||||||
|
already_existed=False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
|
|
||||||
@ -147,7 +153,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
status = DownloadModelStatus(
|
||||||
|
status=DownloadStatusType.COMPLETED,
|
||||||
|
progress_percentage=100,
|
||||||
|
message=f"Successfully downloaded {model_name}",
|
||||||
|
already_existed=False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
return status
|
return status
|
||||||
@ -161,7 +171,10 @@ async def handle_download_error(e: Exception,
|
|||||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||||
relative_path: str) -> DownloadModelStatus:
|
relative_path: str) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(status=DownloadStatusType.ERROR,
|
||||||
|
progress_percentage=0,
|
||||||
|
message=error_message,
|
||||||
|
already_existed=False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ Pillow
|
|||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
psutil
|
psutil
|
||||||
|
pydantic~=2.8
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
kornia>=0.7.1
|
kornia>=0.7.1
|
||||||
|
@ -84,13 +84,19 @@ async def test_download_model_success():
|
|||||||
# Check initial call
|
# Check initial call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.bin',
|
'checkpoints/model.bin',
|
||||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False)
|
DownloadModelStatus(status=DownloadStatusType.PENDING,
|
||||||
|
progress_percentage= 0,
|
||||||
|
message="Starting download of model.bin",
|
||||||
|
already_existed= False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check final call
|
# Check final call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.bin',
|
'checkpoints/model.bin',
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False)
|
DownloadModelStatus(status=DownloadStatusType.COMPLETED,
|
||||||
|
progress_percentage=100,
|
||||||
|
message="Successfully downloaded model.bin",
|
||||||
|
already_existed= False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify file writing
|
# Verify file writing
|
||||||
@ -204,7 +210,10 @@ async def test_check_file_exists_when_file_exists(tmp_path):
|
|||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
mock_callback.assert_called_once_with(
|
||||||
"test/existing_model.bin",
|
"test/existing_model.bin",
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True)
|
DownloadModelStatus(status=DownloadStatusType.COMPLETED,
|
||||||
|
progress_percentage=100,
|
||||||
|
message="existing_model.bin already exists",
|
||||||
|
already_existed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -237,7 +246,7 @@ async def test_track_download_progress_no_content_length():
|
|||||||
# Check that progress was reported even without knowing the total size
|
# Check that progress was reported even without knowing the total size
|
||||||
mock_callback.assert_any_call(
|
mock_callback.assert_any_call(
|
||||||
'models/model.bin',
|
'models/model.bin',
|
||||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False)
|
DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS, progress_percentage= 0, message="Downloading model.bin", already_existed=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user