mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 10:53:29 +00:00
Add sqlite db
This commit is contained in:
parent
1709a8441e
commit
b6b475191d
97
app/database.py
Normal file
97
app/database.py
Normal file
@ -0,0 +1,97 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from queue import Queue, Empty, Full
|
||||
import threading
|
||||
from app.database_updater import DatabaseUpdater
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, database_path=None, pool_size=1):
|
||||
if database_path is None:
|
||||
database_path = "file::memory:?cache=shared"
|
||||
|
||||
self.database_path = database_path
|
||||
self.pool_size = pool_size
|
||||
# Store connections in a pool, default to 1 as normal usage is going to be from a single thread at a time
|
||||
self.connection_pool: Queue = Queue(maxsize=pool_size)
|
||||
self._db_lock = threading.Lock()
|
||||
self._initialized = False
|
||||
self._closing = False
|
||||
|
||||
def _setup(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
with self._db_lock:
|
||||
if not self._initialized:
|
||||
self._make_db()
|
||||
self._initialized = True
|
||||
|
||||
def _create_connection(self):
|
||||
# TODO: Catch error for sqlite lib missing on linux
|
||||
logging.info(f"Creating connection to {self.database_path}")
|
||||
conn = sqlite3.connect(self.database_path, check_same_thread=False)
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
return conn
|
||||
|
||||
def _make_db(self):
|
||||
with self._get_connection() as con:
|
||||
updater = DatabaseUpdater(con)
|
||||
updater.update()
|
||||
|
||||
@contextmanager
|
||||
def _get_connection(self):
|
||||
if self._closing:
|
||||
raise Exception("Database is shutting down")
|
||||
|
||||
try:
|
||||
# Try to get connection from pool
|
||||
connection = self.connection_pool.get_nowait()
|
||||
except Empty:
|
||||
# Create new connection if pool is empty
|
||||
connection = self._create_connection()
|
||||
|
||||
try:
|
||||
yield connection
|
||||
finally:
|
||||
try:
|
||||
# Try to add to pool if it's empty
|
||||
self.connection_pool.put_nowait(connection)
|
||||
except Full:
|
||||
# Pool is full, close the connection
|
||||
connection.close()
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
# Setup the database if it's not already initialized
|
||||
self._setup()
|
||||
with self._get_connection() as connection:
|
||||
yield connection
|
||||
|
||||
def close(self):
|
||||
if self._closing:
|
||||
return
|
||||
# Drain and close all connections in the pool
|
||||
self._closing = True
|
||||
while True:
|
||||
try:
|
||||
conn = self.connection_pool.get_nowait()
|
||||
conn.close()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Create a global instance
|
||||
db_path = None
|
||||
if not args.memory_database:
|
||||
db_path = folder_paths.get_user_directory() + "/comfyui.db"
|
||||
db = Database(db_path)
|
54
app/database_updater.py
Normal file
54
app/database_updater.py
Normal file
@ -0,0 +1,54 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
class DatabaseUpdater:
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self.current_version = self.get_db_version()
|
||||
self.version_updates = {
|
||||
1: self._update_to_v1,
|
||||
}
|
||||
self.max_version = max(self.version_updates.keys())
|
||||
self.update_required = self.current_version < self.max_version
|
||||
logging.info(f"Database version: {self.current_version}")
|
||||
|
||||
def get_db_version(self):
|
||||
return self.connection.execute("PRAGMA user_version").fetchone()[0]
|
||||
|
||||
def update(self):
|
||||
if not self.update_required:
|
||||
return
|
||||
|
||||
logging.info(f"Updating database: {self.current_version} -> {self.max_version}")
|
||||
|
||||
dirname = os.path.dirname(__file__)
|
||||
cursor = self.connection.cursor()
|
||||
for version in range(self.current_version + 1, self.max_version + 1):
|
||||
filename = os.path.join(dirname, f"db/v{version}.sql")
|
||||
if not os.path.exists(filename):
|
||||
raise Exception(
|
||||
f"Database update script for version {version} not found"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
sql = file.read()
|
||||
cursor.executescript(sql)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to execute update script for version {version}: {e}"
|
||||
)
|
||||
|
||||
method = self.version_updates[version]
|
||||
if method is not None:
|
||||
method(cursor)
|
||||
|
||||
cursor.execute("PRAGMA user_version = %d" % self.max_version)
|
||||
self.connection.commit()
|
||||
cursor.close()
|
||||
self.current_version = self.get_db_version()
|
||||
|
||||
def _update_to_v1(self, cursor):
|
||||
# TODO: migrate users and settings
|
||||
print("Updating to v1")
|
25
app/db/v1.sql
Normal file
25
app/db/v1.sql
Normal file
@ -0,0 +1,25 @@
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
path TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
hash TEXT NOT NULL,
|
||||
source_url TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
tags (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
model_tags (
|
||||
model_id INTEGER NOT NULL,
|
||||
tag_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (model_id, tag_id),
|
||||
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
|
||||
);
|
@ -143,6 +143,8 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user