diff --git a/app/user_manager.py b/app/user_manager.py index 20817844..4e545c0d 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -4,15 +4,31 @@ import re import uuid import glob import shutil +import logging from aiohttp import web from urllib import parse from comfy.cli_args import args import folder_paths from .app_settings import AppSettings +from typing import TypedDict default_user = "default" +class FileInfo(TypedDict): + path: str + size: int + modified: int + + +def get_file_info(path: str, relative_to: str) -> FileInfo: + return { + "path": os.path.relpath(path, relative_to).replace(os.sep, '/'), + "size": os.path.getsize(path), + "modified": os.path.getmtime(path) + } + + class UserManager(): def __init__(self): user_directory = folder_paths.get_user_directory() @@ -154,6 +170,7 @@ class UserManager(): recurse = request.rel_url.query.get('recurse', '').lower() == "true" full_info = request.rel_url.query.get('full_info', '').lower() == "true" + split_path = request.rel_url.query.get('split', '').lower() == "true" # Use different patterns based on whether we're recursing or not if recurse: @@ -161,26 +178,21 @@ class UserManager(): else: pattern = os.path.join(glob.escape(path), '*') - results = glob.glob(pattern, recursive=recurse) + def process_full_path(full_path: str) -> FileInfo | str | list[str]: + if full_info: + return get_file_info(full_path, path) - if full_info: - results = [ - { - 'path': os.path.relpath(x, path).replace(os.sep, '/'), - 'size': os.path.getsize(x), - 'modified': os.path.getmtime(x) - } for x in results if os.path.isfile(x) - ] - else: - results = [ - os.path.relpath(x, path).replace(os.sep, '/') - for x in results - if os.path.isfile(x) - ] + rel_path = os.path.relpath(full_path, path).replace(os.sep, '/') + if split_path: + return [rel_path] + rel_path.split('/') - split_path = request.rel_url.query.get('split', '').lower() == "true" - if split_path and not full_info: - results = [[x] + x.split('/') for x in results] + return rel_path + + results = [ + process_full_path(full_path) + for full_path in glob.glob(pattern, recursive=recurse) + if os.path.isfile(full_path) + ] return web.json_response(results) @@ -208,20 +220,51 @@ class UserManager(): @routes.post("/userdata/{file}") async def post_userdata(request): + """ + Upload or update a user data file. + + This endpoint handles file uploads to a user's data directory, with options for + controlling overwrite behavior and response format. + + Query Parameters: + - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true". + - full_info (optional): If "true", returns detailed file information (path, size, modified time). + If "false", returns only the relative file path. + + Path Parameters: + - file: The target file path (URL encoded if necessary). + + Returns: + - 400: If 'file' parameter is missing. + - 403: If the requested path is not allowed. + - 409: If overwrite=false and the file already exists. + - 200: JSON response with either: + - Full file information (if full_info=true) + - Relative file path (if full_info=false) + + The request body should contain the raw file content to be written. + """ path = get_user_data_path(request) if not isinstance(path, str): return path - overwrite = request.query["overwrite"] != "false" + overwrite = request.query.get("overwrite", 'true') != "false" + full_info = request.query.get('full_info', 'false').lower() == "true" + if not overwrite and os.path.exists(path): - return web.Response(status=409) + return web.Response(status=409, text="File already exists") body = await request.read() with open(path, "wb") as f: f.write(body) - resp = os.path.relpath(path, self.get_request_user_filepath(request, None)) + user_path = self.get_request_user_filepath(request, None) + if full_info: + resp = get_file_info(path, user_path) + else: + resp = os.path.relpath(path, user_path) + return web.json_response(resp) @routes.delete("/userdata/{file}") @@ -236,6 +279,30 @@ class UserManager(): @routes.post("/userdata/{file}/move/{dest}") async def move_userdata(request): + """ + Move or rename a user data file. + + This endpoint handles moving or renaming files within a user's data directory, with options for + controlling overwrite behavior and response format. + + Path Parameters: + - file: The source file path (URL encoded if necessary) + - dest: The destination file path (URL encoded if necessary) + + Query Parameters: + - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true". + - full_info (optional): If "true", returns detailed file information (path, size, modified time). + If "false", returns only the relative file path. + + Returns: + - 400: If either 'file' or 'dest' parameter is missing + - 403: If either requested path is not allowed + - 404: If the source file does not exist + - 409: If overwrite=false and the destination file already exists + - 200: JSON response with either: + - Full file information (if full_info=true) + - Relative file path (if full_info=false) + """ source = get_user_data_path(request, check_exists=True) if not isinstance(source, str): return source @@ -244,12 +311,19 @@ class UserManager(): if not isinstance(source, str): return dest - overwrite = request.query["overwrite"] != "false" - if not overwrite and os.path.exists(dest): - return web.Response(status=409) + overwrite = request.query.get("overwrite", 'true') != "false" + full_info = request.query.get('full_info', 'false').lower() == "true" - print(f"moving '{source}' -> '{dest}'") + if not overwrite and os.path.exists(dest): + return web.Response(status=409, text="File already exists") + + logging.info(f"moving '{source}' -> '{dest}'") shutil.move(source, dest) - resp = os.path.relpath(dest, self.get_request_user_filepath(request, None)) + user_path = self.get_request_user_filepath(request, None) + if full_info: + resp = get_file_info(dest, user_path) + else: + resp = os.path.relpath(dest, user_path) + return web.json_response(resp) diff --git a/tests-unit/prompt_server_test/user_manager_test.py b/tests-unit/prompt_server_test/user_manager_test.py index 936c6bd2..7e523cbf 100644 --- a/tests-unit/prompt_server_test/user_manager_test.py +++ b/tests-unit/prompt_server_test/user_manager_test.py @@ -14,7 +14,7 @@ def user_manager(tmp_path): um = UserManager() um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join( tmp_path, file - ) + ) if file else tmp_path return um @@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path): client = await aiohttp_client(app) resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true") assert resp.status == 200 - assert await resp.json() == [ - ["subdir/file1.txt", "subdir", "file1.txt"] - ] + assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]] async def test_listuserdata_invalid_directory(aiohttp_client, app): @@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path): assert "/" in result[0]["path"] # Ensure forward slash is used assert "\\" not in result[0]["path"] # Ensure backslash is not present assert result[0]["path"] == "subdir/file1.txt" + + +async def test_post_userdata_new_file(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + content = b"test content" + resp = await client.post("/userdata/test.txt", data=content) + + assert resp.status == 200 + assert await resp.text() == '"test.txt"' + + # Verify file was created with correct content + with open(tmp_path / "test.txt", "rb") as f: + assert f.read() == content + + +async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "test.txt", "w") as f: + f.write("initial content") + + client = await aiohttp_client(app) + new_content = b"updated content" + resp = await client.post("/userdata/test.txt", data=new_content) + + assert resp.status == 200 + assert await resp.text() == '"test.txt"' + + # Verify file was overwritten + with open(tmp_path / "test.txt", "rb") as f: + assert f.read() == new_content + + +async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "test.txt", "w") as f: + f.write("initial content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content") + + assert resp.status == 409 + + # Verify original content unchanged + with open(tmp_path / "test.txt", "r") as f: + assert f.read() == "initial content" + + +async def test_post_userdata_full_info(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + content = b"test content" + resp = await client.post("/userdata/test.txt?full_info=true", data=content) + + assert resp.status == 200 + result = await resp.json() + assert result["path"] == "test.txt" + assert result["size"] == len(content) + assert "modified" in result + + +async def test_move_userdata(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "source.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/source.txt/move/dest.txt") + + assert resp.status == 200 + assert await resp.text() == '"dest.txt"' + + # Verify file was moved + assert not os.path.exists(tmp_path / "source.txt") + with open(tmp_path / "dest.txt", "r") as f: + assert f.read() == "test content" + + +async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path): + # Create source and destination files + with open(tmp_path / "source.txt", "w") as f: + f.write("source content") + with open(tmp_path / "dest.txt", "w") as f: + f.write("destination content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false") + + assert resp.status == 409 + + # Verify files remain unchanged + with open(tmp_path / "source.txt", "r") as f: + assert f.read() == "source content" + with open(tmp_path / "dest.txt", "r") as f: + assert f.read() == "destination content" + + +async def test_move_userdata_full_info(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "source.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true") + + assert resp.status == 200 + result = await resp.json() + assert result["path"] == "dest.txt" + assert result["size"] == len("test content") + assert "modified" in result + + # Verify file was moved + assert not os.path.exists(tmp_path / "source.txt") + with open(tmp_path / "dest.txt", "r") as f: + assert f.read() == "test content"