From b3ebb2c92defc7d6c3531333f206509ea69c4d65 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:54:59 +0800 Subject: [PATCH] feat(typesense): disable typesense by default --- memos/config.py | 22 ++++++++----- memos/initialize_typesense.py | 17 +++++++++- memos/models.py | 6 ++-- memos/server.py | 58 +++++++++++++++++++++++++---------- 4 files changed, 74 insertions(+), 29 deletions(-) diff --git a/memos/config.py b/memos/config.py index 90d522a..938db1d 100644 --- a/memos/config.py +++ b/memos/config.py @@ -38,6 +38,16 @@ class EmbeddingSettings(BaseModel): use_modelscope: bool = False +class TypesenseSettings(BaseModel): + enabled: bool = False + host: str = "localhost" + port: str = "8108" + protocol: str = "http" + api_key: str = "xyz" + connection_timeout_seconds: int = 10 + collection_name: str = "entities" + + class Settings(BaseSettings): model_config = SettingsConfigDict( yaml_file=str(Path.home() / ".memos" / "config.yaml"), @@ -50,13 +60,6 @@ class Settings(BaseSettings): default_library: str = "screenshots" screenshots_dir: str = os.path.join(base_dir, "screenshots") - typesense_host: str = "localhost" - typesense_port: str = "8108" - typesense_protocol: str = "http" - typesense_api_key: str = "xyz" - typesense_connection_timeout_seconds: int = 10 - typesense_collection_name: str = "entities" - # Server settings server_host: str = "0.0.0.0" server_port: int = 8080 @@ -70,6 +73,9 @@ class Settings(BaseSettings): # Embedding settings embedding: EmbeddingSettings = EmbeddingSettings() + # Typesense settings + typesense: TypesenseSettings = TypesenseSettings() + batchsize: int = 1 auth_username: str = "admin" @@ -136,7 +142,7 @@ settings = Settings() os.makedirs(settings.base_dir, exist_ok=True) # Global variable for Typesense collection name -TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name +TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name # Function to get the database path from environment variable or default diff --git a/memos/initialize_typesense.py b/memos/initialize_typesense.py index 292290b..9456d4c 100644 --- a/memos/initialize_typesense.py +++ b/memos/initialize_typesense.py @@ -1,5 +1,11 @@ import typesense from .config import settings, TYPESENSE_COLLECTION_NAME +import sys + +# Check if Typesense is enabled +if not settings.typesense.enabled: + print("Error: Typesense is not enabled. Please enable it in the configuration.") + sys.exit(1) # Initialize Typesense client client = typesense.Client( @@ -112,6 +118,10 @@ def update_collection_fields(client, schema): def init_typesense(): """Initialize the Typesense collection.""" + if not settings.typesense.enabled: + print("Error: Typesense is not enabled. Please enable it in the configuration.") + return False + try: existing_collections = client.collections.retrieve() collection_names = [c["name"] for c in existing_collections] @@ -134,6 +144,10 @@ def init_typesense(): if __name__ == "__main__": import argparse + if not settings.typesense.enabled: + print("Error: Typesense is not enabled. Please enable it in the configuration.") + sys.exit(1) + parser = argparse.ArgumentParser() parser.add_argument("--force", action="store_true", help="Drop the collection before initializing") args = parser.parse_args() @@ -145,4 +159,5 @@ if __name__ == "__main__": except Exception as e: print(f"Error dropping collection: {e}") - init_typesense() + if not init_typesense(): + sys.exit(1) diff --git a/memos/models.py b/memos/models.py index 5288307..b0bf4c1 100644 --- a/memos/models.py +++ b/memos/models.py @@ -223,9 +223,9 @@ def init_database(): with engine.connect() as conn: conn.execute( DDL( - """ + f""" CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0( - embedding float[768] + embedding float[{settings.embedding.num_dim}] ) """ ) @@ -430,4 +430,4 @@ def delete_fts_and_vec(mapper, connection, target): event.listen(EntityModel, "after_insert", update_fts_and_vec) event.listen(EntityModel, "after_update", update_fts_and_vec) -event.listen(EntityModel, "after_delete", delete_fts_and_vec) +event.listen(EntityModel, "after_delete", delete_fts_and_vec) \ No newline at end of file diff --git a/memos/server.py b/memos/server.py index bd8ed72..9c2ddcd 100644 --- a/memos/server.py +++ b/memos/server.py @@ -17,6 +17,7 @@ import json import cv2 from PIL import Image from secrets import compare_digest +import functools import typesense @@ -58,20 +59,22 @@ engine = create_engine(f"sqlite:///{get_database_path()}") event.listen(engine, "connect", load_extension) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -# Initialize Typesense client -client = typesense.Client( - { - "nodes": [ - { - "host": settings.typesense_host, - "port": settings.typesense_port, - "protocol": settings.typesense_protocol, - } - ], - "api_key": settings.typesense_api_key, - "connection_timeout_seconds": settings.typesense_connection_timeout_seconds, - } -) +# Initialize Typesense client only if enabled +client = None +if settings.typesense.enabled: + client = typesense.Client( + { + "nodes": [ + { + "host": settings.typesense.host, + "port": settings.typesense.port, + "protocol": settings.typesense.protocol, + } + ], + "api_key": settings.typesense.api_key, + "connection_timeout_seconds": settings.typesense.connection_timeout_seconds, + } + ) app.add_middleware( CORSMiddleware, @@ -370,11 +373,24 @@ async def update_entity( return entity +def typesense_required(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + if not settings.typesense.enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Typesense is not enabled", + ) + return await func(*args, **kwargs) + return wrapper + + @app.post( "/entities/{entity_id}/index", status_code=status.HTTP_204_NO_CONTENT, tags=["entity"], ) +@typesense_required async def sync_entity_to_typesense(entity_id: int, db: Session = Depends(get_db)): entity = crud.get_entity_by_id(entity_id, db) if entity is None: @@ -398,6 +414,7 @@ async def sync_entity_to_typesense(entity_id: int, db: Session = Depends(get_db) status_code=status.HTTP_204_NO_CONTENT, tags=["entity"], ) +@typesense_required async def batch_sync_entities_to_typesense( entity_ids: List[int], db: Session = Depends(get_db) ): @@ -423,6 +440,7 @@ async def batch_sync_entities_to_typesense( response_model=EntitySearchResult, tags=["entity"], ) +@typesense_required async def get_entity_index(entity_id: int) -> EntityIndexItem: try: entity_index_item = indexing.fetch_entity_by_id(client, entity_id) @@ -440,6 +458,7 @@ async def get_entity_index(entity_id: int) -> EntityIndexItem: status_code=status.HTTP_204_NO_CONTENT, tags=["entity"], ) +@typesense_required async def remove_entity_from_typesense(entity_id: int, db: Session = Depends(get_db)): try: indexing.remove_entity_by_id(client, entity_id) @@ -456,6 +475,7 @@ async def remove_entity_from_typesense(entity_id: int, db: Session = Depends(get response_model=List[EntityIndexItem], tags=["entity"], ) +@typesense_required def list_entitiy_indices_in_folder( library_id: int, folder_id: int, @@ -479,6 +499,7 @@ def list_entitiy_indices_in_folder( @app.get("/search/v2", response_model=SearchResult, tags=["search"]) +@typesense_required async def search_entities( q: str, library_ids: str = Query(None, description="Comma-separated list of library IDs"), @@ -821,9 +842,12 @@ async def search_entities_v2( def run_server(): print("Database path:", get_database_path()) - print( - f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}, Collection Name: {settings.typesense_collection_name}" - ) + if settings.typesense.enabled: + print( + f"Typesense connection info: Host: {settings.typesense.host}, Port: {settings.typesense.port}, Protocol: {settings.typesense.protocol}, Collection Name: {settings.typesense.collection_name}" + ) + else: + print("Typesense is disabled") print(f"VLM plugin enabled: {settings.vlm}") print(f"OCR plugin enabled: {settings.ocr}")