mirror of
synced 2025-03-18 08:07:07 +00:00

* Add model downloading endpoint.
* Move client session init to async function.
* Break up large function.
* Send "download_progress" as websocket event.
* Fixed
* Fixed.
* Use async mock.
* Move server set up to right before run call.
* Validate that model subdirectory cannot contain relative paths.
* Add download_model test checking for invalid paths.
* Remove DS_Store.
* Consolidate DownloadStatus and DownloadModelResult
* Add progress_interval as an optional parameter.
* Use tuple type from annotations.
* Use pydantic.
* Update comment.
* Revert "Use pydantic."
This reverts commit 7461e8eb00
* Add new line.
* Add newline EOF.
* Validate model filename as well.
* Add comment to not rely on internals.
* Restrict downloading to safetensors files only.
322 lines
12 KiB
322 lines
12 KiB
import pytest
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
class AsyncIteratorMock:
    """
    A mock class that simulates an asynchronous iterator.
    This is used to mimic the behavior of aiohttp's content iterator.
    """

    def __init__(self, seq):
        # Convert the input sequence into an iterator
        self.iter = iter(seq)

    def __aiter__(self):
        # This method is called when 'async for' is used
        return self

    async def __anext__(self):
        # This method is called for each iteration in an 'async for' loop
        try:
            return next(self.iter)
        except StopIteration:
            # A StopIteration escaping a coroutine is turned into RuntimeError,
            # so translate it into the asynchronous end-of-iteration signal.
            # `from None` drops the irrelevant chained StopIteration context.
            raise StopAsyncIteration from None
class ContentMock:
    """
    A mock class that simulates the content attribute of an aiohttp ClientResponse.
    This class provides the iter_chunked method which returns an async iterator of chunks.
    """

    def __init__(self, chunks):
        # Store the chunks that will be returned by the iterator
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        # This method mimics aiohttp's content.iter_chunked()
        # For simplicity in testing, we ignore chunk_size and just return our predefined chunks
        return AsyncIteratorMock(self.chunks)
async def test_download_model_success():
    """Exercise a successful ``download_model`` run against mocked I/O.

    The mocked response reports status 200 and ``Content-Length: 1000`` and
    streams three chunks (500, 300 and 200 bytes) through ``ContentMock``.
    ``create_model_path``, ``check_file_exists``, ``open`` and ``time.time``
    are patched for the duration of the call.

    NOTE(review): this copy of the test is incomplete and will not parse.
    The arguments and closing parenthesis of ``download_model(`` are missing.
    The two ``DownloadModelStatus(...)`` lines appear without the call that
    should consume them (presumably ``mock_progress_callback.assert_any_call``).
    Nothing follows the "Verify request was made" comment. Restore these from
    version control rather than reconstructing them.
    NOTE(review): no ``@pytest.mark.asyncio`` marker is visible. Confirm the
    suite runs pytest-asyncio in auto mode; otherwise this coroutine is not run.
    """
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.status = 200
    mock_response.headers = {'Content-Length': '1000'}

    # Create a mock for content that returns an async iterator directly
    chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
    mock_response.content = ContentMock(chunks)

    mock_make_request = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # Mock file operations
    mock_open = MagicMock()
    mock_file = MagicMock()
    # A `with open(...) as f` in the code under test will receive mock_file,
    # so its write() calls can be inspected below.
    mock_open.return_value.__enter__.return_value = mock_file
    # Unbounded fake clock: each time.time() call advances by 0.1
    time_values = itertools.count(0, 0.1)

    with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
         patch('model_filemanager.check_file_exists', return_value=None), \
         patch('builtins.open', mock_open), \
         patch('time.time', side_effect=time_values): # Simulate time passing
        result = await download_model(
        # NOTE(review): call arguments and closing ")" are missing in this copy

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Successfully downloaded model.sft'
    assert result.status == 'completed'
    assert result.already_existed is False

    # Check progress callback calls
    assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion

    # Check initial call
    # NOTE(review): bare expression. Its value is discarded, so nothing is asserted here.
    DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)

    # Check final call
    # NOTE(review): bare expression. Its value is discarded, so nothing is asserted here.
    DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)

    # Verify file writing
    mock_file.write.assert_any_call(b'a' * 500)
    mock_file.write.assert_any_call(b'b' * 300)
    mock_file.write.assert_any_call(b'c' * 200)

    # Verify request was made
    # NOTE(review): the assertion on mock_make_request that belonged here is missing
async def test_download_model_url_request_failure():
    """Check that an HTTP 404 response yields an error ``DownloadModelStatus``.

    The expected status is 'error' with the message
    'Failed to download model.safetensors. Status code: 404', and
    ``already_existed`` must be False.

    NOTE(review): this copy of the test is incomplete and will not parse.
    The arguments and closing parenthesis of ``download_model(`` are missing.
    The two ``message=...`` lines are keyword-argument fragments of missing
    calls (presumably progress-callback assertions). The ``mock_get``
    verification after the last comment is missing. Restore from version control.
    NOTE(review): no ``@pytest.mark.asyncio`` marker is visible. Confirm the
    suite runs pytest-asyncio in auto mode.
    """
    # Mock dependencies
    mock_response = AsyncMock(spec=ClientResponse)
    mock_response.status = 404 # Simulate a "Not Found" error
    mock_get = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # Mock the create_model_path function
    with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
        # Mock the check_file_exists function to return None (file doesn't exist)
        with patch('model_filemanager.check_file_exists', return_value=None):
            # Call the function
            result = await download_model(
            # NOTE(review): call arguments and closing ")" are missing in this copy

    # Assert the expected behavior
    assert isinstance(result, DownloadModelStatus)
    assert result.status == 'error'
    assert result.message == 'Failed to download model.safetensors. Status code: 404'
    assert result.already_existed is False

    # Check that progress_callback was called with the correct arguments
    # NOTE(review): the two lines below are fragments of missing calls. As
    # statements they would only bind `message`, not assert anything.
    message='Starting download of model.safetensors',
    message='Failed to download model.safetensors. Status code: 404',

    # Verify that the get method was called with the correct URL
    # NOTE(review): the assertion on mock_get that belonged here is missing
async def test_download_model_invalid_model_subdirectory():
    """Check that ``download_model`` reports an invalid model subdirectory.

    The expected result is an error status with the message
    'Invalid model subdirectory' and ``already_existed`` False.

    NOTE(review): the arguments and closing parenthesis of the
    ``download_model(`` call are missing from this copy. Presumably they
    pass a subdirectory that ``validate_model_subdirectory`` rejects.
    Restore them from version control.
    NOTE(review): no ``@pytest.mark.asyncio`` marker is visible. Confirm the
    suite runs pytest-asyncio in auto mode.
    """
    mock_make_request = AsyncMock()
    mock_progress_callback = AsyncMock()

    result = await download_model(
    # NOTE(review): call arguments and closing ")" are missing in this copy

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Invalid model subdirectory'
    assert result.status == 'error'
    assert result.already_existed is False
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
    """Check the absolute and relative paths returned by ``create_model_path``.

    ``folder_paths.models_dir`` is pointed at a temporary ``models`` directory,
    which is also passed to ``create_model_path`` explicitly. Three things must hold:

    - the returned file path is ``<models>/<dir>/<name>``;
    - the relative path is ``"<dir>/<name>"``;
    - the file's parent directory exists afterwards as a directory.
    """
    mock_models_dir = tmp_path / "models"
    monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))

    model_name = "test_model.sft"
    model_directory = "test_dir"

    file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)

    assert file_path == str(mock_models_dir / model_directory / model_name)
    assert relative_path == f"{model_directory}/{model_name}"
    # Use isdir rather than exists: the parent must have been created as a
    # directory, not merely exist as some filesystem entry.
    assert os.path.isdir(os.path.dirname(file_path))
async def test_check_file_exists_when_file_exists(tmp_path):
    """Check that ``check_file_exists`` reports an already-present file.

    An empty file is created under pytest's ``tmp_path``. ``check_file_exists``
    must return a non-None status with status "completed", the message
    "existing_model.sft already exists" and ``already_existed`` True.

    NOTE(review): the final ``DownloadModelStatus(...)`` line appears without
    the call that should consume it (presumably an assertion on
    ``mock_callback``). As written it asserts nothing. Restore from version control.
    NOTE(review): no ``@pytest.mark.asyncio`` marker is visible. Confirm the
    suite runs pytest-asyncio in auto mode.
    """
    file_path = tmp_path / "existing_model.sft"
    file_path.touch() # Create an empty file

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")

    assert result is not None
    assert result.status == "completed"
    assert result.message == "existing_model.sft already exists"
    assert result.already_existed is True

    # NOTE(review): bare expression. Its value is discarded, so nothing is asserted here.
    DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    """Check that ``check_file_exists`` returns None when the file is absent.

    NOTE(review): no ``@pytest.mark.asyncio`` marker is visible on this
    coroutine test. Confirm the suite runs pytest-asyncio in auto mode;
    otherwise it is not executed.
    """
    # The path is deliberately never created.
    file_path = tmp_path / "non_existing_model.sft"
    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")

    assert result is None
async def test_track_download_progress_no_content_length():
    """Check that ``track_download_progress`` completes without a Content-Length header.

    Two 500-byte chunks are streamed with an empty ``headers`` dict and a
    0.1 s reporting interval. The result must have status "completed".

    NOTE(review): this copy is incomplete and will not parse. The closing
    ")" of ``track_download_progress(`` is missing. The final
    ``DownloadModelStatus(...)`` line lacks the call that should consume it
    (presumably an assertion on ``mock_callback``). Restore from version control.
    NOTE(review): no ``@pytest.mark.asyncio`` marker is visible. Confirm the
    suite runs pytest-asyncio in auto mode.
    """
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {} # No Content-Length header
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    with patch('builtins.open', mock_open):
        result = await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=0.1
            # NOTE(review): closing ")" is missing in this copy

    assert result.status == "completed"

    # Check that progress was reported even without knowing the total size
    # NOTE(review): bare expression. Its value is discarded, so nothing is asserted here.
    DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
async def test_track_download_progress_interval():
    """Check that progress is reported periodically during a download.

    Ten 100-byte chunks are streamed against ``Content-Length: 1000`` with a
    1.0 s reporting interval, while the patched ``time.time`` advances 0.5 s
    per call. Expected reports:

    - at least three callbacks (start, one interval update, completion);
    - a first report of "in_progress" below 50 %;
    - a final report of "completed" at 100 %.

    NOTE(review): no ``@pytest.mark.asyncio`` marker is visible on this
    coroutine test. Confirm the suite runs pytest-asyncio in auto mode;
    otherwise it is not executed.
    """
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    # Unbounded fake clock advancing 0.5 s per call. A fixed-length side_effect
    # list raises StopIteration once exhausted. That would tie the test to how
    # often the implementation (or anything else) happens to read time.time().
    mock_time = MagicMock(side_effect=itertools.count(0, 0.5))

    with patch('builtins.open', mock_open), \
         patch('time.time', mock_time):
        await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=1.0
        )

    # The reported status object is the second positional argument of each callback.
    reported = [recorded.args[1] for recorded in mock_callback.call_args_list]
    # Shown in failure messages instead of printing debug output on every run.
    summary = "; ".join(
        f"call {i}: {status.status}, progress {status.progress_percentage:.2f}%"
        for i, status in enumerate(reported, start=1)
    )

    # Progress must be updated at least 3 times (start, at least one interval, and end)
    assert mock_callback.call_count >= 3, (
        f"Expected at least 3 calls, but got {mock_callback.call_count} ({summary})"
    )

    first, last = reported[0], reported[-1]
    assert first.status == "in_progress", summary
    # Allow for some initial progress, but it should be less than 50%
    assert 0 <= first.progress_percentage < 50, f"First call progress was {first.progress_percentage}%"
    assert last.status == "completed", summary
    assert last.progress_percentage == 100, summary
def test_valid_subdirectory():
    """A plain alphanumeric name containing a dash is accepted."""
    assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
    """A 51-character name is rejected."""
    assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
    """A name containing a '..' path-traversal segment is rejected."""
    assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
    """A name containing a '/' separator is rejected."""
    assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
    """A name containing '@' is rejected."""
    assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
    """Underscores and dashes are both accepted."""
    assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
    """The empty string is rejected."""
    assert validate_model_subdirectory("") is False
# NOTE(review): several rejected cases below use .pt/.ckpt extensions. If
# validate_filename only accepts .safetensors/.sft, they may be rejected for
# the extension alone and not exercise the check named in their comment.
# Consider adding .safetensors variants once the intended behaviour is confirmed.
@pytest.mark.parametrize("filename, expected", [
    ("valid_model.safetensors", True),
    ("valid_model.sft", True),
    ("valid model.safetensors", True),  # Test with space
    ("model_with.multiple.dots.pt", False),
    ("", False),  # Empty string
    ("../../../etc/passwd", False),  # Path traversal attempt
    ("/etc/passwd", False),  # Absolute path
    ("\\windows\\system32\\config\\sam", False),  # Windows path
    (".hidden_file.pt", False),  # Hidden file
    ("invalid<char>.ckpt", False),  # Invalid character
    ("invalid?.ckpt", False),  # Another invalid character
    ("very" * 100 + ".safetensors", False),  # Too long filename
    ("\nmodel_with_newline.pt", False),  # Newline character
    ("model_with_emoji😊.pt", False),  # Emoji in filename
])
def test_validate_filename(filename, expected):
    """Check ``validate_filename`` against accepted and rejected filenames."""
    assert validate_filename(filename) == expected