Add FrontendManager to manage non-default front-end impl (#3897)

* Add frontend manager

* Add tests

* nit

* Add unit test to github CI

* Fix path

* nit

* ignore

* Add logging

* Install test deps

* Remove 'stable' keyword support

* Update test

* Add web-root arg

* Rename web-root to front-end-root

* Add test on non-exist version number

* Use repo owner/name to replace hard coded provider list

* Inline cmd args

* nit

* Fix unit test
This commit is contained in:
Chenlei Hu 2024-07-16 11:26:11 -04:00 committed by GitHub
parent 33346fd9b8
commit 99458e8aca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 350 additions and 6 deletions

View File

@ -24,3 +24,7 @@ jobs:
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit

1
.gitignore vendored
View File

@ -18,3 +18,4 @@ venv/
/tests-ui/data/object_info.json
/user/
*.log
web_custom_versions/

0
app/__init__.py Normal file
View File

187
app/frontend_management.py Normal file
View File

@ -0,0 +1,187 @@
import argparse
import logging
import os
import re
import tempfile
import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict
import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING
REQUEST_TIMEOUT = 10 # seconds
class Asset(TypedDict):
url: str
class Release(TypedDict):
id: int
tag_name: str
name: str
prerelease: bool
created_at: str
published_at: str
body: str
assets: NotRequired[list[Asset]]
@dataclass
class FrontEndProvider:
owner: str
repo: str
@property
def folder_name(self) -> str:
return f"{self.owner}_{self.repo}"
@property
def release_url(self) -> str:
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
@cached_property
def all_releases(self) -> list[Release]:
releases = []
api_url = self.release_url
while api_url:
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
releases.extend(response.json())
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
if "next" in response.links:
api_url = response.links["next"]["url"]
else:
api_url = None
return releases
@cached_property
def latest_release(self) -> Release:
latest_release_url = f"{self.release_url}/latest"
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()
def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
return release
raise ValueError(f"Version {version} not found in releases")
def download_release_asset_zip(release: Release, destination_path: str) -> None:
"""Download dist.zip from github release."""
asset_url = None
for asset in release.get("assets", []):
if asset["name"] == "dist.zip":
asset_url = asset["url"]
break
if not asset_url:
raise ValueError("dist.zip not found in the release assets")
# Use a temporary file to download the zip content
with tempfile.TemporaryFile() as tmp_file:
headers = {"Accept": "application/octet-stream"}
response = requests.get(
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
)
response.raise_for_status() # Ensure we got a successful response
# Write the content to the temporary file
tmp_file.write(response.content)
# Go back to the beginning of the temporary file
tmp_file.seek(0)
# Extract the zip file content to the destination path
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
zip_ref.extractall(destination_path)
class FrontendManager:
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Args:
value (str): The version string to parse.
Returns:
tuple[str, str]: A tuple containing provider name and version.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(\d+\.\d+\.\d+|latest)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v")
web_root = str(
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
return web_root
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Args:
version_string (str): The version string to initialize the frontend with.
Returns:
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH

View File

@ -1,7 +1,10 @@
import argparse
import enum
import os
from typing import Optional
import comfy.options
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
@ -124,6 +127,38 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
parser.add_argument(
"--front-end-version",
type=str,
default=DEFAULT_VERSION_STRING,
help="""
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
download available frontend implementations from GitHub releases.
The version string should be in the format of:
[repoOwner]/[repoName]@[version]
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
""",
)
def is_valid_directory(path: Optional[str]) -> Optional[str]:
"""Validate if the given path is a directory."""
if path is None:
return None
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
return path
parser.add_argument(
"--front-end-root",
type=is_valid_directory,
default=None,
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
if comfy.options.args_parsing:
args = parser.parse_args()

View File

@ -1,5 +1,8 @@
[pytest]
markers =
inference: mark as inference test (deselect with '-m "not inference"')
testpaths = tests
testpaths =
tests
tests-unit
addopts = -s
pythonpath = .

View File

@ -25,9 +25,10 @@ import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
@ -83,8 +84,12 @@ class PromptServer():
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
self.web_root = (
FrontendManager.init_frontend(args.front_end_version)
if args.front_end_root is None
else args.front_end_root
)
logging.info(f"[Prompt Server] web root: {self.web_root}")
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None

8
tests-unit/README.md Normal file
View File

@ -0,0 +1,8 @@
# Pytest Unit Tests
## Install test dependencies
`pip install -r tests-units/requirements.txt`
## Run tests
`pytest tests-units/`

View File

View File

@ -0,0 +1,100 @@
import argparse
import pytest
from requests.exceptions import HTTPError
from app.frontend_management import (
FrontendManager,
FrontEndProvider,
Release,
)
from comfy.cli_args import DEFAULT_VERSION_STRING
@pytest.fixture
def mock_releases():
return [
Release(
id=1,
tag_name="1.0.0",
name="Release 1.0.0",
prerelease=False,
created_at="2022-01-01T00:00:00Z",
published_at="2022-01-01T00:00:00Z",
body="Release notes for 1.0.0",
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
),
Release(
id=2,
tag_name="2.0.0",
name="Release 2.0.0",
prerelease=False,
created_at="2022-02-01T00:00:00Z",
published_at="2022-02-01T00:00:00Z",
body="Release notes for 2.0.0",
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
),
]
@pytest.fixture
def mock_provider(mock_releases):
provider = FrontEndProvider(
owner="test-owner",
repo="test-repo",
)
provider.all_releases = mock_releases
provider.latest_release = mock_releases[1]
FrontendManager.PROVIDERS = [provider]
return provider
def test_get_release(mock_provider, mock_releases):
version = "1.0.0"
release = mock_provider.get_release(version)
assert release == mock_releases[0]
def test_get_release_latest(mock_provider, mock_releases):
version = "latest"
release = mock_provider.get_release(version)
assert release == mock_releases[1]
def test_get_release_invalid_version(mock_provider):
version = "invalid"
with pytest.raises(ValueError):
mock_provider.get_release(version)
def test_init_frontend_default():
version_string = DEFAULT_VERSION_STRING
frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
def test_init_frontend_invalid_version():
version_string = "test-owner/test-repo@1.100.99"
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)
def test_init_frontend_invalid_provider():
version_string = "invalid/invalid@latest"
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)
def test_parse_version_string():
version_string = "owner/repo@1.0.0"
repo_owner, repo_name, version = FrontendManager.parse_version_string(
version_string
)
assert repo_owner == "owner"
assert repo_name == "repo"
assert version == "1.0.0"
def test_parse_version_string_invalid():
version_string = "invalid"
with pytest.raises(argparse.ArgumentTypeError):
FrontendManager.parse_version_string(version_string)

View File

@ -0,0 +1 @@
pytest>=7.8.0