fix(typesense): move client initialization into a function

This commit is contained in:
arkohut 2024-10-08 19:49:59 +08:00
parent 86a73ca72e
commit 549bcb6de4
3 changed files with 88 additions and 54 deletions

View File

@ -42,10 +42,11 @@ logging.getLogger("typer").setLevel(logging.ERROR)
# NOTE(review): this is a diff hunk rendered without +/- markers and with
# indentation stripped, so pre-change and post-change lines appear interleaved.
def serve():
"""Run the server after initializing if necessary."""
db_success = init_database()
# NOTE(review): the next two lines look like the pre-change side (Typesense was
# always initialized); the four lines after them look like their replacement.
# Confirm against the raw diff.
ts_success = init_typesense()
if db_success and ts_success:
# Start from True so that init_typesense() is only called, and can only fail
# the startup, when Typesense is enabled in the settings.
ts_success = True
if settings.typesense.enabled:
ts_success = init_typesense()
# NOTE(review): ts_success is still True whenever Typesense is disabled, so the
# "or not settings.typesense.enabled" clause appears to be redundant.
if db_success and (ts_success or not settings.typesense.enabled):
# Deferred import: the server module is only loaded once initialization succeeded.
from .server import run_server
run_server()
else:
print("Server initialization failed. Unable to start the server.")
@ -53,10 +54,12 @@ def serve():
# NOTE(review): this is a diff hunk rendered without +/- markers and with
# indentation stripped, so pre-change and post-change lines appear interleaved.
@app.command()
def init():
# NOTE(review): the two docstrings below look like the old/new pair of a single
# changed line, not two separate statements. Confirm against the raw diff.
"""Initialize the database and Typesense collection."""
"""Initialize the database and Typesense collection if enabled."""
db_success = init_database()
# NOTE(review): the next two lines look like the pre-change side (Typesense was
# always initialized); the four lines after them look like their replacement.
ts_success = init_typesense()
if db_success and ts_success:
# Start from True so that a disabled Typesense does not count as a failure.
ts_success = True
if settings.typesense.enabled:
ts_success = init_typesense()
# NOTE(review): ts_success is still True whenever Typesense is disabled, so the
# "or not settings.typesense.enabled" clause appears to be redundant.
if db_success and (ts_success or not settings.typesense.enabled):
print("Initialization completed successfully.")
else:
print("Initialization failed. Please check the error messages above.")
@ -339,4 +342,4 @@ def disable():
# Script entry point: invoke the CLI `app` when this module is executed directly.
# NOTE(review): the repeated `app()` is most likely the same line shown on both
# the removed and added side of the diff (the hunk is 4 lines before and after),
# not a second call. Confirm against the raw diff.
if __name__ == "__main__":
app()
app()

View File

@ -30,8 +30,8 @@ from collections import defaultdict
from .embedding import generate_embeddings
import logging
from sqlite_vec import serialize_float32
import time
# Set up the logger for this module (added at the top of the file)
logger = logging.getLogger(__name__)
@ -440,18 +440,19 @@ def full_text_search(
if library_ids:
library_ids_str = ", ".join(f"'{id}'" for id in library_ids)
sql_query += f" AND entities.library_id IN ({library_ids_str})"
if start is not None and end is not None:
sql_query += (
" AND strftime('%s', entities.file_created_at) BETWEEN :start AND :end"
)
params["start"] = start
params["end"] = end
params["start"] = str(start)
params["end"] = str(end)
sql_query += " ORDER BY bm25(entities_fts) LIMIT :limit"
result = db.execute(text(sql_query), params).fetchall()
logger.info(f"Full-text search sql: {sql_query}")
logger.info(f"Full-text search params: {params}")
ids = [row[0] for row in result]
logger.debug(f"Full-text search results: {ids}")
return ids
@ -486,8 +487,8 @@ def vec_search(
sql_query += (
" AND strftime('%s', entities.file_created_at) BETWEEN :start AND :end"
)
params["start"] = start
params["end"] = end
params["start"] = str(start)
params["end"] = str(end)
sql_query += " AND K = :limit ORDER BY distance"
@ -521,18 +522,41 @@ def hybrid_search(
start: Optional[int] = None,
end: Optional[int] = None,
) -> List[Entity]:
fts_results = full_text_search(query, db, limit, library_ids, start, end)
vec_results = vec_search(query, db, limit, library_ids, start, end)
start_time = time.time()
fts_start = time.time()
fts_results = full_text_search(query, db, limit, library_ids, start, end)
fts_end = time.time()
logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")
vec_start = time.time()
vec_results = vec_search(query, db, limit, library_ids, start, end)
vec_end = time.time()
logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")
fusion_start = time.time()
combined_results = reciprocal_rank_fusion(fts_results, vec_results)
fusion_end = time.time()
logger.info(f"Reciprocal rank fusion took {fusion_end - fusion_start:.4f} seconds")
sorted_ids = [id for id, _ in combined_results][:limit]
logger.debug(f"Hybrid search results (sorted IDs): {sorted_ids}")
logger.info(f"Hybrid search results (sorted IDs): {sorted_ids}")
entities_start = time.time()
entities = find_entities_by_ids(sorted_ids, db)
entities_end = time.time()
logger.info(
f"Finding entities by IDs took {entities_end - entities_start:.4f} seconds"
)
# Create a dictionary mapping entity IDs to entities
entity_dict = {entity.id: entity for entity in entities}
# Return entities in the order of sorted_ids
return [entity_dict[id] for id in sorted_ids if id in entity_dict]
result = [entity_dict[id] for id in sorted_ids if id in entity_dict]
end_time = time.time()
total_time = end_time - start_time
logger.info(f"Total hybrid search time: {total_time:.4f} seconds")
return result

View File

@ -1,26 +1,11 @@
import typesense
from .config import settings, TYPESENSE_COLLECTION_NAME
import sys
import logging
# NOTE(review): the enabled check and the client construction below are
# module-level statements, so they run at import time; with Typesense disabled,
# importing this module would exit the process. Given the commit title and the
# client that init_typesense() now builds itself, these look like the lines the
# commit removes. Confirm against the raw diff.
# Check if Typesense is enabled
if not settings.typesense.enabled:
print("Error: Typesense is not enabled. Please enable it in the configuration.")
sys.exit(1)
# Initialize Typesense client
client = typesense.Client(
{
"nodes": [
{
"host": settings.typesense_host,
"port": settings.typesense_port,
"protocol": settings.typesense_protocol,
}
],
"api_key": settings.typesense_api_key,
"connection_timeout_seconds": settings.typesense_connection_timeout_seconds,
}
)
# Configure logging
# basicConfig() is called at import time with level INFO; `logger` is the
# module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define the schema for the Typesense collection
schema = {
@ -87,7 +72,6 @@ schema = {
"token_separators": [":", "/", " ", "\\"],
}
def update_collection_fields(client, schema):
existing_collection = client.collections[TYPESENSE_COLLECTION_NAME].retrieve()
existing_fields = {field["name"]: field for field in existing_collection["fields"]}
@ -115,49 +99,72 @@ def update_collection_fields(client, schema):
f"No new fields to add or update in the '{TYPESENSE_COLLECTION_NAME}' collection."
)
# NOTE(review): this is a diff hunk rendered without +/- markers and with
# indentation stripped. Each print(...) that is directly followed by a logger
# call carrying the same message looks like a removed line next to its
# replacement. Confirm against the raw diff.
#
# Returns False when Typesense is disabled or when any exception is raised while
# talking to the server; returns True after the collection was created or its
# fields were updated.
def init_typesense():
"""Initialize the Typesense collection."""
if not settings.typesense.enabled:
print("Error: Typesense is not enabled. Please enable it in the configuration.")
logger.warning("Typesense is not enabled. Skipping initialization.")
return False
try:
# The client is built inside the function, from the host/port/protocol/API-key
# values in `settings`, rather than being read from a module-level variable.
client = typesense.Client(
{
"nodes": [
{
"host": settings.typesense_host,
"port": settings.typesense_port,
"protocol": settings.typesense_protocol,
}
],
"api_key": settings.typesense_api_key,
"connection_timeout_seconds": settings.typesense_connection_timeout_seconds,
}
)
# Create the collection from `schema` if no collection with the configured
# name exists yet; otherwise reconcile its fields with the schema.
existing_collections = client.collections.retrieve()
collection_names = [c["name"] for c in existing_collections]
if TYPESENSE_COLLECTION_NAME not in collection_names:
client.collections.create(schema)
print(
f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully."
)
logger.info(f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully.")
else:
update_collection_fields(client, schema)
print(
f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' already exists. Updated fields if necessary."
)
logger.info(f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' already exists. Updated fields if necessary.")
return True
# Broad catch: every failure is reported through the message below and turned
# into a False return value instead of propagating to the caller.
except Exception as e:
print(f"Error initializing Typesense collection: {e}")
logger.error(f"Error initializing Typesense collection: {e}")
return False
# NOTE(review): with the `return True` inside the try block and `return False`
# in the except block, this trailing return is unreachable; it is presumably the
# removed pre-change line.
return True
# Script entry point: optionally drop the collection (--force), then run
# init_typesense() and exit with status 1 if it reports failure.
# NOTE(review): this is a diff hunk rendered without +/- markers and with
# indentation stripped, so pre-change and post-change lines appear interleaved.
if __name__ == "__main__":
import argparse
# NOTE(review): this first enabled-check (print + exit status 1) looks like the
# removed side; the later check logs a warning and exits with status 0 instead.
if not settings.typesense.enabled:
print("Error: Typesense is not enabled. Please enable it in the configuration.")
sys.exit(1)
# `sys` is imported locally here; in this interleaved view the import follows the
# sys.exit(1) call above, which is consistent with that call being a removed line.
import sys
parser = argparse.ArgumentParser()
parser.add_argument("--force", action="store_true", help="Drop the collection before initializing")
args = parser.parse_args()
# A disabled Typesense is treated as "nothing to do": warn and exit with status 0.
if not settings.typesense.enabled:
logger.warning("Typesense is not enabled. Please enable it in the configuration if you want to use Typesense.")
sys.exit(0)
# This client is only used for the optional --force drop below;
# init_typesense() constructs its own client from the same settings.
client = typesense.Client(
{
"nodes": [
{
"host": settings.typesense_host,
"port": settings.typesense_port,
"protocol": settings.typesense_protocol,
}
],
"api_key": settings.typesense_api_key,
"connection_timeout_seconds": settings.typesense_connection_timeout_seconds,
}
)
if args.force:
# A failed drop is reported but does not stop the script; initialization still
# runs afterwards.
try:
client.collections[TYPESENSE_COLLECTION_NAME].delete()
print(f"Dropped collection '{TYPESENSE_COLLECTION_NAME}'.")
logger.info(f"Dropped collection '{TYPESENSE_COLLECTION_NAME}'.")
except Exception as e:
print(f"Error dropping collection: {e}")
logger.error(f"Error dropping collection: {e}")
if not init_typesense():
sys.exit(1)