mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 18:35:17 +00:00
Break up large function.
This commit is contained in:
parent
8af203ecc6
commit
cce1ccf9c9
@ -1,9 +1,9 @@
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
from folder_paths import models_dir
|
from folder_paths import models_dir
|
||||||
from typing import Callable, Any
|
from typing import Callable, Any, Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
class DownloadStatusType(Enum):
|
class DownloadStatusType(Enum):
|
||||||
@ -14,65 +14,101 @@ class DownloadStatusType(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadStatus():
|
class DownloadStatus():
|
||||||
status: DownloadStatusType
|
status: str
|
||||||
progress_percentage: float
|
progress_percentage: float
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str):
|
||||||
|
self.status = status.value # Store the string value of the Enum
|
||||||
|
self.progress_percentage = progress_percentage
|
||||||
|
self.message = message
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelResult():
|
class DownloadModelResult():
|
||||||
status: DownloadStatusType
|
status: str
|
||||||
message: str
|
message: str
|
||||||
already_existed: bool
|
already_existed: bool
|
||||||
|
|
||||||
|
def __init__(self, status: DownloadStatusType, message: str, already_existed: bool):
|
||||||
|
self.status = status.value # Store the string value of the Enum
|
||||||
|
self.message = message
|
||||||
|
self.already_existed = already_existed
|
||||||
|
|
||||||
async def download_model(session: aiohttp.ClientSession,
|
async def download_model(session: aiohttp.ClientSession,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_directory: str,
|
model_directory: str,
|
||||||
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
|
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
|
||||||
"""
|
file_path, relative_path = create_model_path(model_name, model_directory)
|
||||||
Asynchronously downloads a model file from a given URL to a specified directory.
|
|
||||||
|
|
||||||
If the file already exists, return success.
|
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||||
Downloads the file in chunks and reports progress as a percentage through the callback function.
|
if existing_file:
|
||||||
"""
|
return existing_file
|
||||||
|
|
||||||
full_model_dir = os.path.join(models_dir, model_directory)
|
|
||||||
os.makedirs(full_model_dir, exist_ok=True) # Ensure the directory exists.
|
|
||||||
file_path = os.path.join(full_model_dir, model_name)
|
|
||||||
relative_path = '/'.join([model_directory, model_name])
|
|
||||||
if os.path.exists(file_path):
|
|
||||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
|
|
||||||
await progress_callback(relative_path, status)
|
|
||||||
return {"status": DownloadStatusType.COMPLETED, "message": f"{model_name} already exists", "already_existed": True}
|
|
||||||
try:
|
try:
|
||||||
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
async with session.get(model_url) as response:
|
response = await session.get(model_url)
|
||||||
if response.status != 200:
|
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
|
||||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
|
||||||
await progress_callback(relative_path, status)
|
|
||||||
return {"status": DownloadStatusType.ERROR, "message": f"Failed to download {model_name}. Status code: {response.status} ", "already_existed": False}
|
|
||||||
|
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
if response.status != 200:
|
||||||
downloaded = 0
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
|
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
||||||
|
|
||||||
with open(file_path, 'wb') as f:
|
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
||||||
async for chunk in response.content.iter_chunked(8192):
|
|
||||||
f.write(chunk)
|
|
||||||
downloaded += len(chunk)
|
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
|
||||||
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
|
|
||||||
await progress_callback(relative_path, status)
|
|
||||||
|
|
||||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
|
|
||||||
await progress_callback(relative_path, status)
|
|
||||||
|
|
||||||
return {"status": DownloadStatusType.COMPLETED, "message": f"Successfully downloaded {model_name}", "already_existed": False}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
|
||||||
|
|
||||||
|
def create_model_path(model_name: str, model_directory: str) -> tuple[str, str]:
|
||||||
|
full_model_dir = os.path.join(models_dir, model_directory)
|
||||||
|
os.makedirs(full_model_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(full_model_dir, model_name)
|
||||||
|
relative_path = '/'.join([model_directory, model_name])
|
||||||
|
return file_path, relative_path
|
||||||
|
|
||||||
|
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> Optional[DownloadModelResult]:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return {"status": DownloadStatusType.ERROR, "message": error_message, "already_existed": False}
|
return DownloadModelResult(DownloadStatusType.COMPLETED, f"{model_name} already exists", True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
|
||||||
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
|
downloaded = 0
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
async def update_progress():
|
||||||
|
nonlocal last_update_time
|
||||||
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
|
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
async for chunk in response.content.iter_chunked(8192):
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
|
||||||
|
# Check if it's time to update progress
|
||||||
|
if time.time() - last_update_time >= interval:
|
||||||
|
await update_progress()
|
||||||
|
|
||||||
|
# Ensure we send a final update
|
||||||
|
await update_progress()
|
||||||
|
|
||||||
|
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False)
|
||||||
|
|
||||||
|
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult:
|
||||||
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
|
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
Loading…
Reference in New Issue
Block a user