From b6b475191d4874d66bacb4f2f1d9cdea12c6ad87 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 30 Jan 2025 21:48:53 +0000 Subject: [PATCH] Add sqlite db --- app/database.py | 97 +++++++++++++++++++++++++++++++++++++++++ app/database_updater.py | 54 +++++++++++++++++++++++ app/db/v1.sql | 25 +++++++++++ comfy/cli_args.py | 2 + 4 files changed, 178 insertions(+) create mode 100644 app/database.py create mode 100644 app/database_updater.py create mode 100644 app/db/v1.sql diff --git a/app/database.py b/app/database.py new file mode 100644 index 00000000..6bbb14eb --- /dev/null +++ b/app/database.py @@ -0,0 +1,97 @@ +import logging +import sqlite3 +from contextlib import contextmanager +from queue import Queue, Empty, Full +import threading +from app.database_updater import DatabaseUpdater +import folder_paths +from comfy.cli_args import args + + +class Database: + def __init__(self, database_path=None, pool_size=1): + if database_path is None: + database_path = "file::memory:?cache=shared" + + self.database_path = database_path + self.pool_size = pool_size + # Store connections in a pool, default to 1 as normal usage is going to be from a single thread at a time + self.connection_pool: Queue = Queue(maxsize=pool_size) + self._db_lock = threading.Lock() + self._initialized = False + self._closing = False + + def _setup(self): + if self._initialized: + return + + with self._db_lock: + if not self._initialized: + self._make_db() + self._initialized = True + + def _create_connection(self): + # TODO: Catch error for sqlite lib missing on linux + logging.info(f"Creating connection to {self.database_path}") + conn = sqlite3.connect(self.database_path, check_same_thread=False) + conn.execute("PRAGMA foreign_keys = ON") + return conn + + def _make_db(self): + with self._get_connection() as con: + updater = DatabaseUpdater(con) + updater.update() + + @contextmanager + def _get_connection(self): + if self._closing: + raise Exception("Database is shutting down") + + try: + # Try to get connection from pool + connection = self.connection_pool.get_nowait() + except Empty: + # Create new connection if pool is empty + connection = self._create_connection() + + try: + yield connection + finally: + try: + # Try to add to pool if it's empty + self.connection_pool.put_nowait(connection) + except Full: + # Pool is full, close the connection + connection.close() + + @contextmanager + def get_connection(self): + # Setup the database if it's not already initialized + self._setup() + with self._get_connection() as connection: + yield connection + + def close(self): + if self._closing: + return + # Drain and close all connections in the pool + self._closing = True + while True: + try: + conn = self.connection_pool.get_nowait() + conn.close() + except Empty: + break + + def __del__(self): + try: + self.close() + except: + pass + + +# Create a global instance +db_path = None +if not args.memory_database: + db_path = folder_paths.get_user_directory() + "/comfyui.db" +db = Database(db_path) diff --git a/app/database_updater.py b/app/database_updater.py new file mode 100644 index 00000000..58acb9d6 --- /dev/null +++ b/app/database_updater.py @@ -0,0 +1,54 @@ +import logging +import os + + +class DatabaseUpdater: + def __init__(self, connection): + self.connection = connection + self.current_version = self.get_db_version() + self.version_updates = { + 1: self._update_to_v1, + } + self.max_version = max(self.version_updates.keys()) + self.update_required = self.current_version < self.max_version + logging.info(f"Database version: {self.current_version}") + + def get_db_version(self): + return self.connection.execute("PRAGMA user_version").fetchone()[0] + + def update(self): + if not self.update_required: + return + + logging.info(f"Updating database: {self.current_version} -> {self.max_version}") + + dirname = os.path.dirname(__file__) + cursor = self.connection.cursor() + for version in range(self.current_version + 1, self.max_version + 1): + filename = os.path.join(dirname, f"db/v{version}.sql") + if not os.path.exists(filename): + raise Exception( + f"Database update script for version {version} not found" + ) + + try: + with open(filename, "r") as file: + sql = file.read() + cursor.executescript(sql) + except Exception as e: + raise Exception( + f"Failed to execute update script for version {version}: {e}" + ) + + method = self.version_updates[version] + if method is not None: + method(cursor) + + cursor.execute("PRAGMA user_version = %d" % self.max_version) + self.connection.commit() + cursor.close() + self.current_version = self.get_db_version() + + def _update_to_v1(self, cursor): + # TODO: migrate users and settings + print("Updating to v1") diff --git a/app/db/v1.sql b/app/db/v1.sql new file mode 100644 index 00000000..fd0783a5 --- /dev/null +++ b/app/db/v1.sql @@ -0,0 +1,25 @@ +CREATE TABLE IF NOT EXISTS + models ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + path TEXT NOT NULL, + name TEXT NOT NULL, + model TEXT NOT NULL, + type TEXT NOT NULL, + hash TEXT NOT NULL, + source_url TEXT + ); + +CREATE TABLE IF NOT EXISTS + tags ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE + ); + +CREATE TABLE IF NOT EXISTS + model_tags ( + model_id INTEGER NOT NULL, + tag_id INTEGER NOT NULL, + PRIMARY KEY (model_id, tag_id), + FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE, + FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE + ); diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 812798bf..672e5d84 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -143,6 +143,8 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") +parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.") + # The default built-in provider hosted under web/ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"