feat: init default library

This commit is contained in:
arkohut 2024-09-09 19:57:07 +08:00
parent 8056f19773
commit d7e6c32e86
3 changed files with 46 additions and 16 deletions

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import os import os
import time
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path

View File

@ -47,6 +47,8 @@ class Settings(BaseSettings):
base_dir: str = str(Path.home() / ".memos") base_dir: str = str(Path.home() / ".memos")
database_path: str = os.path.join(base_dir, "database.db") database_path: str = os.path.join(base_dir, "database.db")
default_library: str = "screenshots"
typesense_host: str = "localhost" typesense_host: str = "localhost"
typesense_port: str = "8108" typesense_port: str = "8108"
typesense_protocol: str = "http" typesense_protocol: str = "http"

View File

@ -15,7 +15,7 @@ from typing import List
from .schemas import MetadataSource, MetadataType from .schemas import MetadataSource, MetadataType
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from .config import get_database_path from .config import get_database_path, settings
class Base(DeclarativeBase): class Base(DeclarativeBase):
@ -75,16 +75,14 @@ class EntityModel(Base):
"FolderModel", back_populates="entities" "FolderModel", back_populates="entities"
) )
metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship( metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship(
"EntityMetadataModel", "EntityMetadataModel", lazy="joined", cascade="all, delete-orphan"
lazy="joined",
cascade="all, delete-orphan"
) )
tags: Mapped[List["TagModel"]] = relationship( tags: Mapped[List["TagModel"]] = relationship(
"TagModel", "TagModel",
secondary="entity_tags", secondary="entity_tags",
lazy="joined", lazy="joined",
cascade="all, delete", cascade="all, delete",
overlaps="entities" overlaps="entities",
) )
# 添加索引 # 添加索引
@ -160,30 +158,61 @@ def init_database():
"""Initialize the database.""" """Initialize the database."""
db_path = get_database_path() db_path = get_database_path()
engine = create_engine(f"sqlite:///{db_path}") engine = create_engine(f"sqlite:///{db_path}")
try: try:
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
print(f"Database initialized successfully at {db_path}") print(f"Database initialized successfully at {db_path}")
# Initialize default plugins # Initialize default plugins
Session = sessionmaker(bind=engine) Session = sessionmaker(bind=engine)
with Session() as session: with Session() as session:
initialize_default_plugins(session) default_plugins = initialize_default_plugins(session)
init_default_libraries(session, default_plugins)
return True return True
except OperationalError as e: except OperationalError as e:
print(f"Error initializing database: {e}") print(f"Error initializing database: {e}")
return False return False
def initialize_default_plugins(session): def initialize_default_plugins(session):
default_plugins = [ default_plugins = [
PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"), PluginModel(
PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"), name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"
),
PluginModel(
name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"
),
] ]
for plugin in default_plugins: for plugin in default_plugins:
existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first() existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first()
if not existing_plugin: if not existing_plugin:
session.add(plugin) session.add(plugin)
session.commit() session.commit()
return default_plugins
def init_default_libraries(session, default_plugins):
default_libraries = [
LibraryModel(name=settings.default_library),
]
for library in default_libraries:
existing_library = (
session.query(LibraryModel).filter_by(name=library.name).first()
)
if not existing_library:
session.add(library)
for plugin in default_plugins:
bind_response = session.query(PluginModel).filter_by(name=plugin.name).first()
if bind_response:
library_plugin = LibraryPluginModel(
library_id=1, plugin_id=bind_response.id
) # Assuming library_id=1 for default libraries
session.add(library_plugin)
session.commit()