Break up large function.

This commit is contained in:
Robin Huang 2024-08-06 21:21:51 -07:00
parent 8af203ecc6
commit cce1ccf9c9

View File

@ -1,9 +1,9 @@
import aiohttp import aiohttp
import os import os
from folder_paths import models_dir from folder_paths import models_dir
from typing import Callable, Any from typing import Callable, Any, Optional
from enum import Enum from enum import Enum
import time
from dataclasses import dataclass from dataclasses import dataclass
class DownloadStatusType(Enum): class DownloadStatusType(Enum):
@ -14,65 +14,101 @@ class DownloadStatusType(Enum):
@dataclass @dataclass
class DownloadStatus(): class DownloadStatus():
status: DownloadStatusType status: str
progress_percentage: float progress_percentage: float
message: str message: str
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
@dataclass @dataclass
class DownloadModelResult(): class DownloadModelResult():
status: DownloadStatusType status: str
message: str message: str
already_existed: bool already_existed: bool
def __init__(self, status: DownloadStatusType, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.message = message
self.already_existed = already_existed
async def download_model(session: aiohttp.ClientSession, async def download_model(session: aiohttp.ClientSession,
model_name: str, model_name: str,
model_url: str, model_url: str,
model_directory: str, model_directory: str,
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult: progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
""" file_path, relative_path = create_model_path(model_name, model_directory)
Asynchronously downloads a model file from a given URL to a specified directory.
If the file already exists, return success. existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
Downloads the file in chunks and reports progress as a percentage through the callback function. if existing_file:
""" return existing_file
full_model_dir = os.path.join(models_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True) # Ensure the directory exists.
file_path = os.path.join(full_model_dir, model_name)
relative_path = '/'.join([model_directory, model_name])
if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.COMPLETED, "message": f"{model_name} already exists", "already_existed": True}
try: try:
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
async with session.get(model_url) as response: response = await session.get(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.ERROR, "message": f"Failed to download {model_name}. Status code: {response.status} ", "already_existed": False}
total_size = int(response.headers.get('Content-Length', 0)) if response.status != 200:
downloaded = 0 error_message = f"Failed to download {model_name}. Status code: {response.status}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
with open(file_path, 'wb') as f: return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
async for chunk in response.content.iter_chunked(8192):
f.write(chunk)
downloaded += len(chunk)
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.COMPLETED, "message": f"Successfully downloaded {model_name}", "already_existed": False}
except Exception as e: except Exception as e:
error_message = f"Error downloading {model_name}: {str(e)}" return await handle_download_error(e, model_name, progress_callback, relative_path)
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
def create_model_path(model_name: str, model_directory: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)
relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> Optional[DownloadModelResult]:
if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return {"status": DownloadStatusType.ERROR, "message": error_message, "already_existed": False} return DownloadModelResult(DownloadStatusType.COMPLETED, f"{model_name} already exists", True)
return None
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
last_update_time = time.time()
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
last_update_time = time.time()
with open(file_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
f.write(chunk)
downloaded += len(chunk)
# Check if it's time to update progress
if time.time() - last_update_time >= interval:
await update_progress()
# Ensure we send a final update
await update_progress()
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False)
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)