feat: support initialization

This commit is contained in:
arkohut 2024-08-29 17:47:14 +08:00
parent 42778c2f54
commit a448fd0c9a
3 changed files with 97 additions and 44 deletions

View File

@ -16,6 +16,8 @@ from tqdm import tqdm
from enum import Enum from enum import Enum
from magika import Magika from magika import Magika
from .config import settings from .config import settings
from .models import init_database
from .initialize_typesense import init_typesense
IS_THUMBNAIL = "is_thumbnail" IS_THUMBNAIL = "is_thumbnail"
@ -29,7 +31,7 @@ app.add_typer(lib_app, name="lib")
file_detector = Magika() file_detector = Magika()
BASE_URL = f"http://localhost:{settings.server_port}" BASE_URL = f"http://{settings.server_host}:{settings.server_port}"
ignore_files = [".DS_Store", ".screen_sequences", "worklog"] ignore_files = [".DS_Store", ".screen_sequences", "worklog"]
@ -84,7 +86,13 @@ def display_libraries(libraries):
@app.command() @app.command()
def serve(): def serve():
run_server() """Run the server after initializing if necessary."""
db_success = init_database()
ts_success = init_typesense()
if db_success and ts_success:
run_server()
else:
print("Server initialization failed. Unable to start the server.")
@lib_app.command("ls") @lib_app.command("ls")
@ -743,5 +751,16 @@ def unbind(
) )
@app.command()
def init():
    """Initialize the database and Typesense collection."""
    # Both steps always run (no short-circuit), so each one gets the chance
    # to print its own error message before the summary below.
    database_ready = init_database()
    typesense_ready = init_typesense()

    if not (database_ready and typesense_ready):
        print("Initialization failed. Please check the error messages above.")
        return
    print("Initialization completed successfully.")
if __name__ == "__main__": if __name__ == "__main__":
app() app()

View File

@ -25,9 +25,27 @@ schema = {
{"name": "filename", "type": "string", "infix": True}, {"name": "filename", "type": "string", "infix": True},
{"name": "size", "type": "int32"}, {"name": "size", "type": "int32"},
{"name": "file_created_at", "type": "int64", "facet": False}, {"name": "file_created_at", "type": "int64", "facet": False},
{"name": "created_date", "type": "string", "facet": True, "optional": True, "sort": True}, {
{"name": "created_month", "type": "string", "facet": True, "optional": True, "sort": True}, "name": "created_date",
{"name": "created_year", "type": "string", "facet": True, "optional": True, "sort": True}, "type": "string",
"facet": True,
"optional": True,
"sort": True,
},
{
"name": "created_month",
"type": "string",
"facet": True,
"optional": True,
"sort": True,
},
{
"name": "created_year",
"type": "string",
"facet": True,
"optional": True,
"sort": True,
},
{"name": "file_last_modified_at", "type": "int64", "facet": False}, {"name": "file_last_modified_at", "type": "int64", "facet": False},
{"name": "file_type", "type": "string", "facet": True}, {"name": "file_type", "type": "string", "facet": True},
{"name": "file_type_group", "type": "string", "facet": True}, {"name": "file_type_group", "type": "string", "facet": True},
@ -39,7 +57,7 @@ schema = {
"type": "string[]", "type": "string[]",
"facet": True, "facet": True,
"optional": True, "optional": True,
"locale": "zh" "locale": "zh",
}, },
{ {
"name": "metadata_entries", "name": "metadata_entries",
@ -52,12 +70,12 @@ schema = {
"name": "embedding", "name": "embedding",
"type": "float[]", "type": "float[]",
"num_dim": settings.embedding.num_dim, "num_dim": settings.embedding.num_dim,
"optional": True, "optional": True,
}, },
{ {
"name": "image_embedding", "name": "image_embedding",
"type": "float[]", "type": "float[]",
"optional": True, "optional": True,
}, },
], ],
"token_separators": [":", "/", " ", "\\"], "token_separators": [":", "/", " ", "\\"],
@ -92,34 +110,39 @@ def update_collection_fields(client, schema):
) )
if __name__ == "__main__": def init_typesense():
import sys """Initialize the Typesense collection."""
force_recreate = "--force" in sys.argv
try: try:
# Check if the collection exists existing_collections = client.collections.retrieve()
existing_collection = client.collections[TYPESENSE_COLLECTION_NAME].retrieve() collection_names = [c["name"] for c in existing_collections]
if TYPESENSE_COLLECTION_NAME not in collection_names:
if force_recreate:
client.collections[TYPESENSE_COLLECTION_NAME].delete()
print(
f"Existing Typesense collection '{TYPESENSE_COLLECTION_NAME}' deleted successfully."
)
client.collections.create(schema) client.collections.create(schema)
print( print(
f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' recreated successfully." f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully."
) )
else: else:
# Update the fields of the existing collection
update_collection_fields(client, schema) update_collection_fields(client, schema)
print(
except typesense.exceptions.ObjectNotFound: f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' already exists. Updated fields if necessary."
# Collection doesn't exist, create it )
client.collections.create(schema)
print(
f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully."
)
except Exception as e: except Exception as e:
print(f"An error occurred: {str(e)}") print(f"Error initializing Typesense collection: {e}")
return False
return True
if __name__ == "__main__":
    import argparse

    # Script entry point: an optional --force flag drops the existing
    # collection before the normal initialization runs.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--force",
        action="store_true",
        help="Drop the collection before initializing",
    )
    options = cli.parse_args()

    if options.force:
        try:
            client.collections[TYPESENSE_COLLECTION_NAME].delete()
            print(f"Dropped collection '{TYPESENSE_COLLECTION_NAME}'.")
        except Exception as drop_error:
            # Best effort: a failed drop is reported and initialization
            # still proceeds.
            print(f"Error dropping collection: {drop_error}")

    init_typesense()

View File

@ -12,8 +12,10 @@ from sqlalchemy import (
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column
from typing import List from typing import List
from .config import get_database_path
from .schemas import MetadataSource, MetadataType from .schemas import MetadataSource, MetadataType
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from .config import get_database_path
class Base(DeclarativeBase): class Base(DeclarativeBase):
@ -154,6 +156,25 @@ class LibraryPluginModel(Base):
) )
def init_database():
    """Create the SQLite schema and seed the default plugins.

    Builds an engine for the path returned by ``get_database_path()``,
    creates every table registered on ``Base``, then runs
    ``initialize_default_plugins`` inside a session.

    Returns:
        True when setup completes; False when SQLAlchemy raises
        ``OperationalError`` (the error is printed to stdout).
    """
    database_file = get_database_path()
    engine = create_engine(f"sqlite:///{database_file}")
    try:
        Base.metadata.create_all(engine)
        print(f"Database initialized successfully at {database_file}")

        # Initialize default plugins in a session that is closed on exit.
        session_factory = sessionmaker(bind=engine)
        with session_factory() as session:
            initialize_default_plugins(session)
    except OperationalError as exc:
        print(f"Error initializing database: {exc}")
        return False
    return True
def initialize_default_plugins(session): def initialize_default_plugins(session):
default_plugins = [ default_plugins = [
PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"), PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
@ -165,14 +186,4 @@ def initialize_default_plugins(session):
if not existing_plugin: if not existing_plugin:
session.add(plugin) session.add(plugin)
session.commit() session.commit()
# Create the database engine with the path from config
engine = create_engine(f"sqlite:///{get_database_path()}")
Base.metadata.create_all(engine)
# Initialize default plugins
from sqlalchemy.orm import sessionmaker
Session = sessionmaker(bind=engine)
with Session() as session:
initialize_default_plugins(session)