mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
fix(typesense): move init client in function
This commit is contained in:
parent
86a73ca72e
commit
549bcb6de4
@ -42,10 +42,11 @@ logging.getLogger("typer").setLevel(logging.ERROR)
|
||||
def serve():
|
||||
"""Run the server after initializing if necessary."""
|
||||
db_success = init_database()
|
||||
ts_success = init_typesense()
|
||||
if db_success and ts_success:
|
||||
ts_success = True
|
||||
if settings.typesense.enabled:
|
||||
ts_success = init_typesense()
|
||||
if db_success and (ts_success or not settings.typesense.enabled):
|
||||
from .server import run_server
|
||||
|
||||
run_server()
|
||||
else:
|
||||
print("Server initialization failed. Unable to start the server.")
|
||||
@ -53,10 +54,12 @@ def serve():
|
||||
|
||||
@app.command()
|
||||
def init():
|
||||
"""Initialize the database and Typesense collection."""
|
||||
"""Initialize the database and Typesense collection if enabled."""
|
||||
db_success = init_database()
|
||||
ts_success = init_typesense()
|
||||
if db_success and ts_success:
|
||||
ts_success = True
|
||||
if settings.typesense.enabled:
|
||||
ts_success = init_typesense()
|
||||
if db_success and (ts_success or not settings.typesense.enabled):
|
||||
print("Initialization completed successfully.")
|
||||
else:
|
||||
print("Initialization failed. Please check the error messages above.")
|
||||
@ -339,4 +342,4 @@ def disable():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
app()
|
@ -30,8 +30,8 @@ from collections import defaultdict
|
||||
from .embedding import generate_embeddings
|
||||
import logging
|
||||
from sqlite_vec import serialize_float32
|
||||
import time
|
||||
|
||||
# 在文件顶部添加这行代码来设置日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -440,18 +440,19 @@ def full_text_search(
|
||||
if library_ids:
|
||||
library_ids_str = ", ".join(f"'{id}'" for id in library_ids)
|
||||
sql_query += f" AND entities.library_id IN ({library_ids_str})"
|
||||
|
||||
if start is not None and end is not None:
|
||||
sql_query += (
|
||||
" AND strftime('%s', entities.file_created_at) BETWEEN :start AND :end"
|
||||
)
|
||||
params["start"] = start
|
||||
params["end"] = end
|
||||
params["start"] = str(start)
|
||||
params["end"] = str(end)
|
||||
|
||||
sql_query += " ORDER BY bm25(entities_fts) LIMIT :limit"
|
||||
|
||||
result = db.execute(text(sql_query), params).fetchall()
|
||||
|
||||
logger.info(f"Full-text search sql: {sql_query}")
|
||||
logger.info(f"Full-text search params: {params}")
|
||||
ids = [row[0] for row in result]
|
||||
logger.debug(f"Full-text search results: {ids}")
|
||||
return ids
|
||||
@ -486,8 +487,8 @@ def vec_search(
|
||||
sql_query += (
|
||||
" AND strftime('%s', entities.file_created_at) BETWEEN :start AND :end"
|
||||
)
|
||||
params["start"] = start
|
||||
params["end"] = end
|
||||
params["start"] = str(start)
|
||||
params["end"] = str(end)
|
||||
|
||||
sql_query += " AND K = :limit ORDER BY distance"
|
||||
|
||||
@ -521,18 +522,41 @@ def hybrid_search(
|
||||
start: Optional[int] = None,
|
||||
end: Optional[int] = None,
|
||||
) -> List[Entity]:
|
||||
fts_results = full_text_search(query, db, limit, library_ids, start, end)
|
||||
vec_results = vec_search(query, db, limit, library_ids, start, end)
|
||||
start_time = time.time()
|
||||
|
||||
fts_start = time.time()
|
||||
fts_results = full_text_search(query, db, limit, library_ids, start, end)
|
||||
fts_end = time.time()
|
||||
logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")
|
||||
|
||||
vec_start = time.time()
|
||||
vec_results = vec_search(query, db, limit, library_ids, start, end)
|
||||
vec_end = time.time()
|
||||
logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")
|
||||
|
||||
fusion_start = time.time()
|
||||
combined_results = reciprocal_rank_fusion(fts_results, vec_results)
|
||||
fusion_end = time.time()
|
||||
logger.info(f"Reciprocal rank fusion took {fusion_end - fusion_start:.4f} seconds")
|
||||
|
||||
sorted_ids = [id for id, _ in combined_results][:limit]
|
||||
logger.debug(f"Hybrid search results (sorted IDs): {sorted_ids}")
|
||||
logger.info(f"Hybrid search results (sorted IDs): {sorted_ids}")
|
||||
|
||||
entities_start = time.time()
|
||||
entities = find_entities_by_ids(sorted_ids, db)
|
||||
entities_end = time.time()
|
||||
logger.info(
|
||||
f"Finding entities by IDs took {entities_end - entities_start:.4f} seconds"
|
||||
)
|
||||
|
||||
# Create a dictionary mapping entity IDs to entities
|
||||
entity_dict = {entity.id: entity for entity in entities}
|
||||
|
||||
# Return entities in the order of sorted_ids
|
||||
return [entity_dict[id] for id in sorted_ids if id in entity_dict]
|
||||
result = [entity_dict[id] for id in sorted_ids if id in entity_dict]
|
||||
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
logger.info(f"Total hybrid search time: {total_time:.4f} seconds")
|
||||
|
||||
return result
|
||||
|
@ -1,26 +1,11 @@
|
||||
import typesense
|
||||
from .config import settings, TYPESENSE_COLLECTION_NAME
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# Check if Typesense is enabled
|
||||
if not settings.typesense.enabled:
|
||||
print("Error: Typesense is not enabled. Please enable it in the configuration.")
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize Typesense client
|
||||
client = typesense.Client(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"host": settings.typesense_host,
|
||||
"port": settings.typesense_port,
|
||||
"protocol": settings.typesense_protocol,
|
||||
}
|
||||
],
|
||||
"api_key": settings.typesense_api_key,
|
||||
"connection_timeout_seconds": settings.typesense_connection_timeout_seconds,
|
||||
}
|
||||
)
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define the schema for the Typesense collection
|
||||
schema = {
|
||||
@ -87,7 +72,6 @@ schema = {
|
||||
"token_separators": [":", "/", " ", "\\"],
|
||||
}
|
||||
|
||||
|
||||
def update_collection_fields(client, schema):
|
||||
existing_collection = client.collections[TYPESENSE_COLLECTION_NAME].retrieve()
|
||||
existing_fields = {field["name"]: field for field in existing_collection["fields"]}
|
||||
@ -115,49 +99,72 @@ def update_collection_fields(client, schema):
|
||||
f"No new fields to add or update in the '{TYPESENSE_COLLECTION_NAME}' collection."
|
||||
)
|
||||
|
||||
|
||||
def init_typesense():
|
||||
"""Initialize the Typesense collection."""
|
||||
if not settings.typesense.enabled:
|
||||
print("Error: Typesense is not enabled. Please enable it in the configuration.")
|
||||
logger.warning("Typesense is not enabled. Skipping initialization.")
|
||||
return False
|
||||
|
||||
try:
|
||||
client = typesense.Client(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"host": settings.typesense_host,
|
||||
"port": settings.typesense_port,
|
||||
"protocol": settings.typesense_protocol,
|
||||
}
|
||||
],
|
||||
"api_key": settings.typesense_api_key,
|
||||
"connection_timeout_seconds": settings.typesense_connection_timeout_seconds,
|
||||
}
|
||||
)
|
||||
|
||||
existing_collections = client.collections.retrieve()
|
||||
collection_names = [c["name"] for c in existing_collections]
|
||||
if TYPESENSE_COLLECTION_NAME not in collection_names:
|
||||
client.collections.create(schema)
|
||||
print(
|
||||
f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully."
|
||||
)
|
||||
logger.info(f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' created successfully.")
|
||||
else:
|
||||
update_collection_fields(client, schema)
|
||||
print(
|
||||
f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' already exists. Updated fields if necessary."
|
||||
)
|
||||
logger.info(f"Typesense collection '{TYPESENSE_COLLECTION_NAME}' already exists. Updated fields if necessary.")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error initializing Typesense collection: {e}")
|
||||
logger.error(f"Error initializing Typesense collection: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
if not settings.typesense.enabled:
|
||||
print("Error: Typesense is not enabled. Please enable it in the configuration.")
|
||||
sys.exit(1)
|
||||
import sys
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--force", action="store_true", help="Drop the collection before initializing")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not settings.typesense.enabled:
|
||||
logger.warning("Typesense is not enabled. Please enable it in the configuration if you want to use Typesense.")
|
||||
sys.exit(0)
|
||||
|
||||
client = typesense.Client(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"host": settings.typesense_host,
|
||||
"port": settings.typesense_port,
|
||||
"protocol": settings.typesense_protocol,
|
||||
}
|
||||
],
|
||||
"api_key": settings.typesense_api_key,
|
||||
"connection_timeout_seconds": settings.typesense_connection_timeout_seconds,
|
||||
}
|
||||
)
|
||||
|
||||
if args.force:
|
||||
try:
|
||||
client.collections[TYPESENSE_COLLECTION_NAME].delete()
|
||||
print(f"Dropped collection '{TYPESENSE_COLLECTION_NAME}'.")
|
||||
logger.info(f"Dropped collection '{TYPESENSE_COLLECTION_NAME}'.")
|
||||
except Exception as e:
|
||||
print(f"Error dropping collection: {e}")
|
||||
logger.error(f"Error dropping collection: {e}")
|
||||
|
||||
if not init_typesense():
|
||||
sys.exit(1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user