diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index c88cf958..abf8e3f7 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -149,6 +149,28 @@ async def test_download_model_url_request_failure(): # Verify that the get method was called with the correct URL mock_get.assert_called_once_with('http://example.com/model.safetensors') +@pytest.mark.asyncio +async def test_download_model_invalid_model_subdirectory(): + + mock_make_request = AsyncMock() + mock_progress_callback = AsyncMock() + + + result = await download_model( + mock_make_request, + 'model.bin', + 'http://example.com/model.bin', + '../bad_path', + mock_progress_callback + ) + + # Assert the result + assert isinstance(result, DownloadModelResult) + assert result.message == 'Invalid model subdirectory' + assert result.status == 'error' + assert result.already_existed is False + + # For create_model_path function def test_create_model_path(tmp_path, monkeypatch): mock_models_dir = tmp_path / "models"