From 549bcb6de4eb2f9df67e2c30595ab905b9b4a357 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:49:59 +0800 Subject: [PATCH] fix(typesense): move init client in function --- memos/commands.py | 17 +++++--- memos/crud.py | 44 ++++++++++++++----- memos/initialize_typesense.py | 81 +++++++++++++++++++---------------- 3 files changed, 88 insertions(+), 54 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index c447ff4..1e34166 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -42,10 +42,11 @@ logging.getLogger("typer").setLevel(logging.ERROR) def serve(): """Run the server after initializing if necessary.""" db_success = init_database() - ts_success = init_typesense() - if db_success and ts_success: + ts_success = True + if settings.typesense.enabled: + ts_success = init_typesense() + if db_success and (ts_success or not settings.typesense.enabled): from .server import run_server - run_server() else: print("Server initialization failed. Unable to start the server.") @@ -53,10 +54,12 @@ def serve(): @app.command() def init(): - """Initialize the database and Typesense collection.""" + """Initialize the database and Typesense collection if enabled.""" db_success = init_database() - ts_success = init_typesense() - if db_success and ts_success: + ts_success = True + if settings.typesense.enabled: + ts_success = init_typesense() + if db_success and (ts_success or not settings.typesense.enabled): print("Initialization completed successfully.") else: print("Initialization failed. Please check the error messages above.") @@ -339,4 +342,4 @@ def disable(): if __name__ == "__main__": - app() + app() \ No newline at end of file diff --git a/memos/crud.py b/memos/crud.py index ac5df01..da87a21 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -30,8 +30,8 @@ from collections import defaultdict from .embedding import generate_embeddings import logging from sqlite_vec import serialize_float32 +import time -# 在文件顶部添加这行代码来设置日志记录器 logger = logging.getLogger(__name__) @@ -440,18 +440,19 @@ def full_text_search( if library_ids: library_ids_str = ", ".join(f"'{id}'" for id in library_ids) sql_query += f" AND entities.library_id IN ({library_ids_str})" - if start is not None and end is not None: sql_query += ( " AND strftime('%s', entities.file_created_at) BETWEEN :start AND :end" ) - params["start"] = start - params["end"] = end + params["start"] = str(start) + params["end"] = str(end) sql_query += " ORDER BY bm25(entities_fts) LIMIT :limit" result = db.execute(text(sql_query), params).fetchall() + logger.info(f"Full-text search sql: {sql_query}") + logger.info(f"Full-text search params: {params}") ids = [row[0] for row in result] logger.debug(f"Full-text search results: {ids}") return ids @@ -486,8 +487,8 @@ def vec_search( sql_query += ( " AND strftime('%s', entities.file_created_at) BETWEEN :start AND :end" ) - params["start"] = start - params["end"] = end + params["start"] = str(start) + params["end"] = str(end) sql_query += " AND K = :limit ORDER BY distance" @@ -521,18 +522,41 @@ def hybrid_search( start: Optional[int] = None, end: Optional[int] = None, ) -> List[Entity]: - fts_results = full_text_search(query, db, limit, library_ids, start, end) - vec_results = vec_search(query, db, limit, library_ids, start, end) + start_time = time.time() + fts_start = time.time() + fts_results = full_text_search(query, db, limit, library_ids, start, end) + fts_end = time.time() + logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds") + + vec_start = time.time() + vec_results = vec_search(query, db, limit, library_ids, start, end) + vec_end = time.time() + logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds") + + fusion_start = time.time() combined_results = reciprocal_rank_fusion(fts_results, vec_results) + fusion_end = time.time() + logger.info(f"Reciprocal rank fusion took {fusion_end - fusion_start:.4f} seconds") sorted_ids = [id for id, _ in combined_results][:limit] - logger.debug(f"Hybrid search results (sorted IDs): {sorted_ids}") + logger.info(f"Hybrid search results (sorted IDs): {sorted_ids}") + entities_start = time.time() entities = find_entities_by_ids(sorted_ids, db) + entities_end = time.time() + logger.info( + f"Finding entities by IDs took {entities_end - entities_start:.4f} seconds" + ) # Create a dictionary mapping entity IDs to entities entity_dict = {entity.id: entity for entity in entities} # Return entities in the order of sorted_ids - return [entity_dict[id] for id in sorted_ids if id in entity_dict] + result = [entity_dict[id] for id in sorted_ids if id in entity_dict] + + end_time = time.time() + total_time = end_time - start_time + logger.info(f"Total hybrid search time: {total_time:.4f} seconds") + + return result diff --git a/memos/initialize_typesense.py b/memos/initialize_typesense.py index 9456d4c..a5bb936 100644 --- a/memos/initialize_typesense.py +++ b/memos/initialize_typesense.py @@ -1,26 +1,11 @@ import typesense from .config import settings, TYPESENSE_COLLECTION_NAME import sys +import logging -# Check if Typesense is enabled -if not settings.typesense.enabled: - print("Error: Typesense is not enabled. Please enable it in the configuration.") - sys.exit(1) - -# Initialize Typesense client -client = typesense.Client( - { - "nodes": [ - { - "host": settings.typesense_host, - "port": settings.typesense_port, - "protocol": settings.typesense_protocol, - } - ], - "api_key": settings.typesense_api_key, - "connection_timeout_seconds": settings.typesense_connection_timeout_seconds, - } -) +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) # Define the schema for the Typesense collection schema = { @@ -87,7 +72,6 @@ schema = { "token_separators": [":", "/", " ", "\\"], } - def update_collection_fields(client, schema): existing_collection = client.collections[TYPESENSE_COLLECTION_NAME].retrieve() existing_fields = {field["name"]: field for field in existing_collection["fields"]} @@ -115,49 +99,72 @@ def update_collection_fields(client, schema): f"No new fields to add or update in the '{TYPESENSE_COLLECTION_NAME}' collection." ) - def init_typesense(): """Initialize the Typesense collection.""" if not settings.typesense.enabled: - print("Error: Typesense is not enabled. Please enable it in the configuration.") + logger.warning("Typesense is not enabled. Skipping initialization.") return False try: + client = typesense.Client( + { + "nodes": [ + { + "host": settings.typesense_host, + "port": settings.typesense_port, + "protocol": settings.typesense_protocol, + } + ], + "api_key": settings.typesense_api_key, + "connection_timeout_seconds": settings.typesense_connection_timeout_seconds, + } + ) + existing_collections = client.collections.retrieve() collection_names = [c["name"] for c in existing_collections] if TYPESENSE_COLLECTION_NAME not in collection_names: client.collections.create(schema) - print( - f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully." - ) + logger.info(f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully.") else: update_collection_fields(client, schema) - print( - f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' already exists. Updated fields if necessary." - ) + logger.info(f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' already exists. Updated fields if necessary.") + return True except Exception as e: - print(f"Error initializing Typesense collection: {e}") + logger.error(f"Error initializing Typesense collection: {e}") return False - return True - if __name__ == "__main__": import argparse - - if not settings.typesense.enabled: - print("Error: Typesense is not enabled. Please enable it in the configuration.") - sys.exit(1) + import sys parser = argparse.ArgumentParser() parser.add_argument("--force", action="store_true", help="Drop the collection before initializing") args = parser.parse_args() + if not settings.typesense.enabled: + logger.warning("Typesense is not enabled. Please enable it in the configuration if you want to use Typesense.") + sys.exit(0) + + client = typesense.Client( + { + "nodes": [ + { + "host": settings.typesense_host, + "port": settings.typesense_port, + "protocol": settings.typesense_protocol, + } + ], + "api_key": settings.typesense_api_key, + "connection_timeout_seconds": settings.typesense_connection_timeout_seconds, + } + ) + if args.force: try: client.collections[TYPESENSE_COLLECTION_NAME].delete() - print(f"Dropped collection '{TYPESENSE_COLLECTION_NAME}'.") + logger.info(f"Dropped collection '{TYPESENSE_COLLECTION_NAME}'.") except Exception as e: - print(f"Error dropping collection: {e}") + logger.error(f"Error dropping collection: {e}") if not init_typesense(): sys.exit(1)