mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Internal download API: Add proper validated directory input (#4981)
* add internal /folder_paths route returns a json maps of folder paths * (minor) format download_models.py * initial folder path input on download api * actually, require folder_path and clean up some code * partial tests update * fix & logging * also download to a tmp file not the live file to avoid compounding errors from network failure * update tests again * test tweaks * workaround the first tests blocker * fix file handling in tests * rewrite test for create_model_path * minor doc fix * avoid 'mock_directory' use temp dir to avoid accidental fs pollution from tests
This commit is contained in:
parent
479a427a48
commit
08c8968482
@ -1,2 +1,2 @@
|
|||||||
# model_manager/__init__.py
|
# model_manager/__init__.py
|
||||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
||||||
|
@ -3,7 +3,7 @@ import aiohttp
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -17,6 +17,7 @@ class DownloadStatusType(Enum):
|
|||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
@ -29,7 +30,7 @@ class DownloadModelStatus():
|
|||||||
self.progress_percentage = progress_percentage
|
self.progress_percentage = progress_percentage
|
||||||
self.message = message
|
self.message = message
|
||||||
self.already_existed = already_existed
|
self.already_existed = already_existed
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"status": self.status,
|
"status": self.status,
|
||||||
@ -38,102 +39,112 @@ class DownloadModelStatus():
|
|||||||
"already_existed": self.already_existed
|
"already_existed": self.already_existed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_directory: str,
|
||||||
|
folder_path: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
Download a model file from a given URL into the models directory.
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||||
model_name (str):
|
model_name (str):
|
||||||
The name of the model file to be downloaded. This will be the filename on disk.
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
model_url (str):
|
model_url (str):
|
||||||
The URL from which to download the model.
|
The URL from which to download the model.
|
||||||
model_sub_directory (str):
|
model_directory (str):
|
||||||
The subdirectory within the main models directory where the model
|
The subdirectory within the main models directory where the model
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
An asynchronous function to call with progress updates.
|
An asynchronous function to call with progress updates.
|
||||||
|
folder_path (str);
|
||||||
|
Path to which model folder should be used as the root.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DownloadModelStatus: The result of the download operation.
|
DownloadModelStatus: The result of the download operation.
|
||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid model subdirectory",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validate_filename(model_name):
|
if not validate_filename(model_name):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
0,
|
0,
|
||||||
"Invalid model name",
|
"Invalid model name",
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
if not model_directory in folder_names_and_paths:
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not folder_path in get_folder_paths(model_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = create_model_path(model_name, folder_path)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||||
if existing_file:
|
if existing_file:
|
||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logging.info(f"Downloading {model_name} from {model_url}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in downloading model: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
|
||||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
||||||
os.makedirs(full_model_dir, exist_ok=True)
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
file_path = os.path.join(full_model_dir, model_name)
|
file_path = os.path.join(folder_path, model_name)
|
||||||
|
|
||||||
# Ensure the resulting path is still within the base directory
|
# Ensure the resulting path is still within the base directory
|
||||||
abs_file_path = os.path.abspath(file_path)
|
abs_file_path = os.path.abspath(file_path)
|
||||||
abs_base_dir = os.path.abspath(str(models_base_dir))
|
abs_base_dir = os.path.abspath(folder_path)
|
||||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
relative_path = '/'.join([model_directory, model_name])
|
async def check_file_exists(file_path: str,
|
||||||
return file_path, relative_path
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
||||||
async def check_file_exists(file_path: str,
|
) -> Optional[DownloadModelStatus]:
|
||||||
model_name: str,
|
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
|
||||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str,
|
|
||||||
interval: float = 1.0) -> DownloadModelStatus:
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
try:
|
try:
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
@ -144,10 +155,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
|
|
||||||
with open(file_path, 'wb') as f:
|
temp_file_path = file_path + '.tmp'
|
||||||
|
with open(temp_file_path, 'wb') as f:
|
||||||
chunk_iterator = response.content.iter_chunked(8192)
|
chunk_iterator = response.content.iter_chunked(8192)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -156,58 +168,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
break
|
break
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
downloaded += len(chunk)
|
downloaded += len(chunk)
|
||||||
|
|
||||||
if time.time() - last_update_time >= interval:
|
if time.time() - last_update_time >= interval:
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
|
os.rename(temp_file_path, file_path)
|
||||||
|
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
return status
|
return status
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
async def handle_download_error(e: Exception,
|
|
||||||
model_name: str,
|
async def handle_download_error(e: Exception,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
model_name: str,
|
||||||
relative_path: str) -> DownloadModelStatus:
|
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
||||||
|
) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that the model subdirectory is safe to install into.
|
|
||||||
Must not contain relative paths, nested paths or special characters
|
|
||||||
other than underscores and hyphens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_subdirectory (str): The subdirectory for the specific model type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the subdirectory is safe, False otherwise.
|
|
||||||
"""
|
|
||||||
if len(model_subdirectory) > 50:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def validate_filename(filename: str)-> bool:
|
def validate_filename(filename: str)-> bool:
|
||||||
"""
|
"""
|
||||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename (str): The filename to validate
|
filename (str): The filename to validate
|
||||||
|
|
||||||
|
@ -689,10 +689,11 @@ class PromptServer():
|
|||||||
data = await request.json()
|
data = await request.json()
|
||||||
url = data.get('url')
|
url = data.get('url')
|
||||||
model_directory = data.get('model_directory')
|
model_directory = data.get('model_directory')
|
||||||
|
folder_path = data.get('folder_path')
|
||||||
model_filename = data.get('model_filename')
|
model_filename = data.get('model_filename')
|
||||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename:
|
if not url or not model_directory or not model_filename or not folder_path:
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
session = self.client_session
|
session = self.client_session
|
||||||
@ -700,7 +701,7 @@ class PromptServer():
|
|||||||
logging.error("Client session is not initialized")
|
logging.error("Client session is not initialized")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return web.json_response(task.result().to_dict())
|
return web.json_response(task.result().to_dict())
|
||||||
|
@ -1,10 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import tempfile
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientResponse
|
from aiohttp import ClientResponse
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
@ -42,7 +49,7 @@ class ContentMock:
|
|||||||
return AsyncIteratorMock(self.chunks)
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_success():
|
async def test_download_model_success(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
@ -53,15 +60,13 @@ async def test_download_model_success():
|
|||||||
mock_make_request = AsyncMock(return_value=mock_response)
|
mock_make_request = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
# Mock file operations
|
|
||||||
mock_open = MagicMock()
|
|
||||||
mock_file = MagicMock()
|
|
||||||
mock_open.return_value.__enter__.return_value = mock_file
|
|
||||||
time_values = itertools.count(0, 0.1)
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
patch('builtins.open', mock_open), \
|
patch('folder_paths.folder_names_and_paths', fake_paths), \
|
||||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
@ -69,6 +74,7 @@ async def test_download_model_success():
|
|||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'checkpoints',
|
'checkpoints',
|
||||||
|
temp_dir,
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -83,44 +89,48 @@ async def test_download_model_success():
|
|||||||
|
|
||||||
# Check initial call
|
# Check initial call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check final call
|
# Check final call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify file writing
|
mock_file_path = os.path.join(temp_dir, 'model.sft')
|
||||||
mock_file.write.assert_any_call(b'a' * 500)
|
assert os.path.exists(mock_file_path)
|
||||||
mock_file.write.assert_any_call(b'b' * 300)
|
with open(mock_file_path, 'rb') as mock_file:
|
||||||
mock_file.write.assert_any_call(b'c' * 200)
|
assert mock_file.read() == b''.join(chunks)
|
||||||
|
os.remove(mock_file_path)
|
||||||
|
|
||||||
# Verify request was made
|
# Verify request was made
|
||||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_url_request_failure():
|
async def test_download_model_url_request_failure(temp_dir):
|
||||||
# Mock dependencies
|
# Mock dependencies
|
||||||
mock_response = AsyncMock(spec=ClientResponse)
|
mock_response = AsyncMock(spec=ClientResponse)
|
||||||
mock_response.status = 404 # Simulate a "Not Found" error
|
mock_response.status = 404 # Simulate a "Not Found" error
|
||||||
mock_get = AsyncMock(return_value=mock_response)
|
mock_get = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
# Mock the create_model_path function
|
# Mock the create_model_path function
|
||||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
||||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
patch('folder_paths.folder_names_and_paths', fake_paths):
|
||||||
# Call the function
|
# Call the function
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_get,
|
mock_get,
|
||||||
'model.safetensors',
|
'model.safetensors',
|
||||||
'http://example.com/model.safetensors',
|
'http://example.com/model.safetensors',
|
||||||
'mock_directory',
|
'checkpoints',
|
||||||
mock_progress_callback
|
temp_dir,
|
||||||
)
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
# Assert the expected behavior
|
# Assert the expected behavior
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
# Check that progress_callback was called with the correct arguments
|
# Check that progress_callback was called with the correct arguments
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.PENDING,
|
status=DownloadStatusType.PENDING,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
mock_progress_callback.assert_called_with(
|
mock_progress_callback.assert_called_with(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.ERROR,
|
status=DownloadStatusType.ERROR,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_invalid_model_subdirectory():
|
async def test_download_model_invalid_model_subdirectory():
|
||||||
|
|
||||||
mock_make_request = AsyncMock()
|
mock_make_request = AsyncMock()
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_make_request,
|
mock_make_request,
|
||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'../bad_path',
|
'../bad_path',
|
||||||
|
'../bad_path',
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
assert result.message == 'Invalid model subdirectory'
|
assert result.message.startswith('Invalid or unrecognized model directory')
|
||||||
assert result.status == 'error'
|
assert result.status == 'error'
|
||||||
assert result.already_existed is False
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_invalid_folder_path():
|
||||||
|
mock_make_request = AsyncMock()
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'checkpoints',
|
||||||
|
'invalid_path',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message.startswith("Invalid folder path")
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
# For create_model_path function
|
|
||||||
def test_create_model_path(tmp_path, monkeypatch):
|
def test_create_model_path(tmp_path, monkeypatch):
|
||||||
mock_models_dir = tmp_path / "models"
|
model_name = "model.safetensors"
|
||||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
folder_path = os.path.join(tmp_path, "mock_dir")
|
||||||
|
|
||||||
model_name = "test_model.sft"
|
file_path = create_model_path(model_name, folder_path)
|
||||||
model_directory = "test_dir"
|
|
||||||
|
assert file_path == os.path.join(folder_path, "model.safetensors")
|
||||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
|
||||||
|
|
||||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
|
||||||
assert relative_path == f"{model_directory}/{model_name}"
|
|
||||||
assert os.path.exists(os.path.dirname(file_path))
|
assert os.path.exists(os.path.dirname(file_path))
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("../path_traversal.safetensors", folder_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("/etc/some_root_path", folder_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||||
file_path = tmp_path / "existing_model.sft"
|
file_path = tmp_path / "existing_model.sft"
|
||||||
file_path.touch() # Create an empty file
|
file_path.touch() # Create an empty file
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
assert result.message == "existing_model.sft already exists"
|
assert result.message == "existing_model.sft already exists"
|
||||||
assert result.already_existed is True
|
assert result.already_existed is True
|
||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
mock_callback.assert_called_once_with(
|
||||||
"test/existing_model.sft",
|
"existing_model.sft",
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||||
file_path = tmp_path / "non_existing_model.sft"
|
file_path = tmp_path / "non_existing_model.sft"
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_callback.assert_not_called()
|
mock_callback.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_no_content_length():
|
async def test_track_download_progress_no_content_length(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {} # No Content-Length header
|
mock_response.headers = {} # No Content-Length header
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
chunks = [b'a' * 500, b'b' * 500]
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
|
||||||
|
|
||||||
with patch('builtins.open', mock_open):
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
result = await track_download_progress(
|
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
result = await track_download_progress(
|
||||||
mock_callback, 'models/model.sft', interval=0.1
|
mock_response, full_path, 'model.sft',
|
||||||
)
|
mock_callback, interval=0.1
|
||||||
|
)
|
||||||
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
assert os.path.exists(full_path)
|
||||||
|
with open(full_path, 'rb') as f:
|
||||||
|
assert f.read() == b''.join(chunks)
|
||||||
|
os.remove(full_path)
|
||||||
|
|
||||||
# Check that progress was reported even without knowing the total size
|
# Check that progress was reported even without knowing the total size
|
||||||
mock_callback.assert_any_call(
|
mock_callback.assert_any_call(
|
||||||
'models/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_interval():
|
async def test_track_download_progress_interval(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
chunks = [b'a' * 100] * 10
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
|
|||||||
mock_time = MagicMock()
|
mock_time = MagicMock()
|
||||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||||
|
|
||||||
with patch('builtins.open', mock_open), \
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
patch('time.time', mock_time):
|
|
||||||
await track_download_progress(
|
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
|
||||||
mock_callback, 'models/model.sft', interval=1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print out the actual call count and the arguments of each call for debugging
|
with patch('time.time', mock_time):
|
||||||
print(f"mock_callback was called {mock_callback.call_count} times")
|
await track_download_progress(
|
||||||
for i, call in enumerate(mock_callback.call_args_list):
|
mock_response, full_path, 'model.sft',
|
||||||
args, kwargs = call
|
mock_callback, interval=1.0
|
||||||
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
)
|
||||||
|
|
||||||
|
assert os.path.exists(full_path)
|
||||||
|
with open(full_path, 'rb') as f:
|
||||||
|
assert f.read() == b''.join(chunks)
|
||||||
|
os.remove(full_path)
|
||||||
|
|
||||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||||
@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
|
|||||||
assert last_call[0][1].status == "completed"
|
assert last_call[0][1].status == "completed"
|
||||||
assert last_call[0][1].progress_percentage == 100
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
def test_valid_subdirectory():
|
|
||||||
assert validate_model_subdirectory("valid-model123") is True
|
|
||||||
|
|
||||||
def test_subdirectory_too_long():
|
|
||||||
assert validate_model_subdirectory("a" * 51) is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_double_dots():
|
|
||||||
assert validate_model_subdirectory("model/../unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_slash():
|
|
||||||
assert validate_model_subdirectory("model/unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_special_characters():
|
|
||||||
assert validate_model_subdirectory("model@unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_underscore_and_dash():
|
|
||||||
assert validate_model_subdirectory("valid_model-name") is True
|
|
||||||
|
|
||||||
def test_empty_subdirectory():
|
|
||||||
assert validate_model_subdirectory("") is False
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("filename, expected", [
|
@pytest.mark.parametrize("filename, expected", [
|
||||||
("valid_model.safetensors", True),
|
("valid_model.safetensors", True),
|
||||||
("valid_model.sft", True),
|
("valid_model.sft", True),
|
||||||
|
Loading…
Reference in New Issue
Block a user