diff --git a/memos/cmds/library.py b/memos/cmds/library.py
index aece10b..dc46759 100644
--- a/memos/cmds/library.py
+++ b/memos/cmds/library.py
@@ -507,7 +507,10 @@ async def add_entity(
         post_response = await client.post(
             f"{BASE_URL}/libraries/{library_id}/entities",
             json=new_entity,
-            params={"plugins": plugins} if plugins else {},
+            params={
+                "plugins": plugins,
+                "update_index": "true"
+            } if plugins else {"update_index": "true"},
             timeout=60,
         )
         if 200 <= post_response.status_code < 300:
@@ -546,6 +549,7 @@ async def update_entity(
             json=new_entity,
             params={
                 "trigger_webhooks_flag": "true",
+                "update_index": "true",
                 **({"plugins": plugins} if plugins else {}),
             },
             timeout=60,
@@ -581,6 +585,7 @@ def reindex(
     force: bool = typer.Option(
         False, "--force", help="Force recreate FTS and vector tables before reindexing"
     ),
+    batch_size: int = typer.Option(1, "--batch-size", "-bs", help="Batch size for processing entities"),
 ):
     print(f"Reindexing library {library_id}")
 
@@ -608,7 +613,7 @@ def reindex(
         recreate_fts_and_vec_tables()
         print("FTS and vector tables have been recreated.")
 
-    with httpx.Session() as client:
+    with httpx.Client() as client:
         total_entities = 0
 
         # Get total entity count for all folders
@@ -647,56 +652,34 @@ def reindex(
                 if not entities:
                     break
 
-                # Update last_scan_at for each entity
-                for entity in entities:
-                    if entity["id"] in scanned_entities:
-                        continue
-
-                    update_response = client.post(
-                        f"{BASE_URL}/entities/{entity['id']}/last-scan-at"
-                    )
-                    if update_response.status_code != 204:
-                        print(
-                            f"Failed to update last_scan_at for entity {entity['id']}: {update_response.status_code} - {update_response.text}"
+                # Collect the entity IDs that still need indexing
+                entity_ids = [
+                    entity["id"]
+                    for entity in entities
+                    if entity["id"] not in scanned_entities
+                ]
+
+                # Process in batches of batch_size
+                for i in range(0, len(entity_ids), batch_size):
+                    batch_ids = entity_ids[i:i + batch_size]
+                    if batch_ids:
+                        batch_response = client.post(
+                            f"{BASE_URL}/entities/batch-index",
+                            json={"entity_ids": batch_ids},
+                            timeout=60,
                         )
-                    else:
-                        scanned_entities.add(entity["id"])
-
-                    pbar.update(1)
+                        if batch_response.status_code != 204:
+                            print(
+                                f"Failed to update batch: {batch_response.status_code} - {batch_response.text}"
+                            )
+                        pbar.update(len(batch_ids))
+                        scanned_entities.update(batch_ids)
 
                 offset += limit
 
     print(f"Reindexing completed for library {library_id}")
 
 
-async def check_and_index_entity(client, entity_id, entity_last_scan_at):
-    try:
-        index_response = client.get(f"{BASE_URL}/entities/{entity_id}/index")
-        if index_response.status_code == 200:
-            index_data = index_response.json()
-            if index_data["last_scan_at"] is None:
-                return entity_last_scan_at is not None
-            index_last_scan_at = datetime.fromtimestamp(index_data["last_scan_at"])
-            entity_last_scan_at = datetime.fromisoformat(entity_last_scan_at)
-
-            if index_last_scan_at >= entity_last_scan_at:
-                return False  # Index is up to date, no need to update
-        return True  # Index doesn't exist or needs update
-    except httpx.HTTPStatusError as e:
-        if e.response.status_code == 404:
-            return True  # Index doesn't exist, need to create
-        raise  # Re-raise other HTTP errors
-
-
-async def index_batch(client, entity_ids):
-    index_response = client.post(
-        f"{BASE_URL}/entities/batch-index",
-        json=entity_ids,
-        timeout=60,
-    )
-    return index_response
-
-
 @lib_app.command("sync")
 def sync(
     library_id: int,
@@ -799,7 +782,10 @@ def sync(
             update_response = httpx.put(
                 f"{BASE_URL}/entities/{existing_entity['id']}",
                 json=new_entity,
-                params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
+                params={
+                    "trigger_webhooks_flag": str(not without_webhooks).lower(),
+                    "update_index": "true",
+                },
                 timeout=60,
             )
             if update_response.status_code == 200:
@@ -829,7 +815,10 @@ def sync(
         create_response = httpx.post(
             f"{BASE_URL}/libraries/{library_id}/entities",
             json=new_entity,
-            params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
+            params={
+                "trigger_webhooks_flag": str(not without_webhooks).lower(),
+                "update_index": "true",
+            },
             timeout=60,
         )
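Note: the reindex loop above now collects unscanned IDs and posts them to `/entities/batch-index` in slices of `batch_size`, rather than touching each entity's `last-scan-at` one request at a time. A standalone sketch of the slicing behavior (the IDs are made up):

```python
# Minimal sketch of the batching pattern used in reindex(); IDs are hypothetical.
entity_ids = [101, 102, 103, 104, 105]
batch_size = 2

for i in range(0, len(entity_ids), batch_size):
    batch_ids = entity_ids[i : i + batch_size]
    # Each non-empty slice becomes one POST /entities/batch-index request
    # with body {"entity_ids": batch_ids}.
    print(batch_ids)  # [101, 102], then [103, 104], then [105]
```

With the default `batch_size=1` this stays one request per entity; larger values trade per-request overhead for bigger server-side transactions.
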
diff --git a/memos/commands.py b/memos/commands.py
index d0320f8..665d9a3 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -191,7 +191,10 @@ def scan_default_library(
 def reindex_default_library(
     force: bool = typer.Option(
         False, "--force", help="Force recreate FTS and vector tables before reindexing"
-    )
+    ),
+    batch_size: int = typer.Option(
+        1, "--batch-size", "-bs", help="Batch size for processing entities"
+    ),
 ):
     """
     Reindex the default library for memos.
@@ -215,7 +218,7 @@ def reindex_default_library(
 
     # Reindex the library
     print(f"Reindexing library: {default_library['name']}")
-    reindex(default_library["id"], force=force, folders=None)
+    reindex(default_library["id"], force=force, folders=None, batch_size=batch_size)
 
 
 @app.command("record")
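For reference, a self-contained typer app with the same option shape shows how the flag is parsed; the command name here is hypothetical:

```python
import typer

app = typer.Typer()


@app.command("reindex")
def reindex_cmd(
    batch_size: int = typer.Option(
        1, "--batch-size", "-bs", help="Batch size for processing entities"
    ),
):
    # Accepts `reindex --batch-size 32` or the short form `reindex -bs 32`.
    print(f"batch_size={batch_size}")


if __name__ == "__main__":
    app()
```
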
diff --git a/memos/crud.py b/memos/crud.py
index 9e35954..ad5d2c4 100644
--- a/memos/crud.py
+++ b/memos/crud.py
@@ -25,13 +25,12 @@ from .models import (
     EntityMetadataModel,
     EntityTagModel,
 )
-import numpy as np
 from collections import defaultdict
 from .embedding import get_embeddings
 import logging
 from sqlite_vec import serialize_float32
 import time
-import asyncio
+import json
 
 logger = logging.getLogger(__name__)
 
@@ -100,7 +99,11 @@ def add_folders(library_id: int, folders: NewFoldersParam, db: Session) -> Libra
     return Library(**db_library.__dict__)
 
 
-def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity:
+def create_entity(
+    library_id: int,
+    entity: NewEntityParam,
+    db: Session,
+) -> Entity:
     tags = entity.tags
     metadata_entries = entity.metadata_entries
 
@@ -146,6 +149,7 @@ def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entit
         db.add(entity_metadata)
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
 
 
@@ -185,15 +189,11 @@ def remove_entity(entity_id: int, db: Session):
     entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
     if entity:
         # Delete the entity from FTS and vec tables first
+        db.execute(text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity_id})
         db.execute(
-            text("DELETE FROM entities_fts WHERE id = :id"),
-            {"id": entity_id}
+            text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity_id}
         )
-        db.execute(
-            text("DELETE FROM entities_vec WHERE rowid = :id"),
-            {"id": entity_id}
-        )
-
+
         # Then delete the entity itself
         db.delete(entity)
         db.commit()
@@ -241,7 +241,9 @@ def find_entities_by_ids(entity_ids: List[int], db: Session) -> List[Entity]:
 
 
 def update_entity(
-    entity_id: int, updated_entity: UpdateEntityParam, db: Session
+    entity_id: int,
+    updated_entity: UpdateEntityParam,
+    db: Session,
 ) -> Entity:
     db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
 
@@ -298,6 +300,7 @@ def update_entity(
 
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
 
 
@@ -312,14 +315,18 @@ def touch_entity(entity_id: int, db: Session) -> bool:
         return False
 
 
-def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
+def update_entity_tags(
+    entity_id: int,
+    tags: List[str],
+    db: Session,
+) -> Entity:
     db_entity = get_entity_by_id(entity_id, db)
     if not db_entity:
         raise ValueError(f"Entity with id {entity_id} not found")
 
     # Clear existing tags
     db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete()
-    
+
     for tag_name in tags:
         tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
         if not tag:
@@ -333,12 +340,13 @@ def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
             source=MetadataSource.PLUGIN_GENERATED,
         )
         db.add(entity_tag)
-    
+
     # Update last_scan_at in the same transaction
     db_entity.last_scan_at = func.now()
-    
+
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
 
 
@@ -363,17 +371,20 @@ def add_new_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
             source=MetadataSource.PLUGIN_GENERATED,
         )
         db.add(entity_tag)
-    
+
     # Update last_scan_at in the same transaction
     db_entity.last_scan_at = func.now()
-    
+
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
 
 
 def update_entity_metadata_entries(
-    entity_id: int, updated_metadata: List[EntityMetadataParam], db: Session
+    entity_id: int,
+    updated_metadata: List[EntityMetadataParam],
+    db: Session,
 ) -> Entity:
     db_entity = get_entity_by_id(entity_id, db)
 
@@ -421,6 +432,7 @@ def update_entity_metadata_entries(
 
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
 
 
@@ -478,7 +490,9 @@ def full_text_search(
         params["start"] = str(start)
         params["end"] = str(end)
 
-    sql_query += " ORDER BY bm25(entities_fts), entities.file_created_at DESC LIMIT :limit"
+    sql_query += (
+        " ORDER BY bm25(entities_fts), entities.file_created_at DESC LIMIT :limit"
+    )
 
     result = db.execute(text(sql_query), params).fetchall()
 
@@ -489,7 +503,7 @@
     return ids
 
 
-async def vec_search(
+def vec_search(
     query: str,
     db: Session,
     limit: int = 200,
@@ -497,7 +511,7 @@
     start: Optional[int] = None,
     end: Optional[int] = None,
 ) -> List[int]:
-    query_embedding = await get_embeddings([query])
+    query_embedding = get_embeddings([query])
    if not query_embedding:
        return []
 
@@ -517,9 +531,7 @@
         sql_query += f" AND entities.library_id IN ({library_ids_str})"
 
     if start is not None and end is not None:
-        sql_query += (
-            " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
-        )
+        sql_query += " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
         params["start"] = str(start)
         params["end"] = str(end)
 
@@ -547,7 +559,7 @@ def reciprocal_rank_fusion(
     return sorted_results
 
 
-async def hybrid_search(
+def hybrid_search(
     query: str,
     db: Session,
     limit: int = 200,
@@ -563,7 +575,7 @@
     logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")
 
     vec_start = time.time()
-    vec_results = await vec_search(query, db, limit, library_ids, start, end)
+    vec_results = vec_search(query, db, limit, library_ids, start, end)
     vec_end = time.time()
     logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")
 
@@ -595,7 +607,7 @@
     return result
 
 
-async def list_entities(
+def list_entities(
     db: Session,
     limit: int = 200,
     library_ids: Optional[List[int]] = None,
@@ -609,7 +621,7 @@
     if start is not None and end is not None:
         query = query.filter(
-            func.strftime("%s", EntityModel.file_created_at, 'utc').between(
+            func.strftime("%s", EntityModel.file_created_at, "utc").between(
                 str(start), str(end)
             )
         )
@@ -635,10 +647,10 @@ def get_entity_context(
         )
         .first()
     )
-    
+
     if not target_entity:
         return [], []
-    
+
     # Get previous entities
     prev_entities = []
     if prev > 0:
@@ -646,7 +658,7 @@
             db.query(EntityModel)
             .filter(
                 EntityModel.library_id == library_id,
-                EntityModel.file_created_at < target_entity.file_created_at
+                EntityModel.file_created_at < target_entity.file_created_at,
             )
             .order_by(EntityModel.file_created_at.desc())
             .limit(prev)
@@ -654,7 +666,7 @@
         )
         # Reverse the list to get chronological order and convert to Entity models
         prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1]
-    
+
     # Get next entities
     next_entities = []
     if next > 0:
@@ -662,7 +674,7 @@
             db.query(EntityModel)
             .filter(
                 EntityModel.library_id == library_id,
-                EntityModel.file_created_at > target_entity.file_created_at
+                EntityModel.file_created_at > target_entity.file_created_at,
             )
             .order_by(EntityModel.file_created_at.asc())
             .limit(next)
@@ -670,5 +682,171 @@
         )
         # Convert to Entity models
         next_entities = [Entity(**entity.__dict__) for entity in next_entities]
-    
+
     return prev_entities, next_entities
+
+
+def process_ocr_result(value, max_length=4096):
+    try:
+        ocr_data = json.loads(value)
+        if isinstance(ocr_data, list) and all(
+            isinstance(item, dict)
+            and "dt_boxes" in item
+            and "rec_txt" in item
+            and "score" in item
+            for item in ocr_data
+        ):
+            return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
+        else:
+            return json.dumps(ocr_data, indent=2)
+    except json.JSONDecodeError:
+        return value
+
+
+def prepare_fts_data(entity: EntityModel) -> tuple[str, str]:
+    tags = ", ".join([tag.name for tag in entity.tags])
+    fts_metadata = "\n".join(
+        [
+            f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
+            for entry in entity.metadata_entries
+        ]
+    )
+    return tags, fts_metadata
+
+
+def prepare_vec_data(entity: EntityModel) -> str:
+    vec_metadata = "\n".join(
+        [
+            f"{entry.key}: {entry.value}"
+            for entry in entity.metadata_entries
+            if entry.key != "ocr_result"
+        ]
+    )
+    ocr_result = next(
+        (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
+        "",
+    )
+    vec_metadata += f"\nocr_result: {process_ocr_result(ocr_result, max_length=128)}"
+    return vec_metadata
+
+
+def update_entity_index(entity: EntityModel, db: Session):
+    """Update both FTS and vector indexes for an entity"""
+    try:
+        # Update FTS index
+        tags, fts_metadata = prepare_fts_data(entity)
+        db.execute(
+            text(
+                """
+                INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
+                VALUES(:id, :filepath, :tags, :metadata)
+                """
+            ),
+            {
+                "id": entity.id,
+                "filepath": entity.filepath,
+                "tags": tags,
+                "metadata": fts_metadata,
+            },
+        )
+
+        # Update vector index
+        vec_metadata = prepare_vec_data(entity)
+        embeddings = get_embeddings([vec_metadata])
+
+        if embeddings and embeddings[0]:
+            db.execute(
+                text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
+            )
+            db.execute(
+                text(
+                    """
+                    INSERT INTO entities_vec (rowid, embedding)
+                    VALUES (:id, :embedding)
+                    """
+                ),
+                {
+                    "id": entity.id,
+                    "embedding": serialize_float32(embeddings[0]),
+                },
+            )
+
+        db.commit()
+    except Exception as e:
+        logger.error(f"Error updating indexes for entity {entity.id}: {e}")
+        db.rollback()
+        raise
+
+
+def batch_update_entity_indices(entity_ids: List[int], db: Session):
+    """Batch update both FTS and vector indexes for multiple entities"""
+    try:
+        # Fetch the entities
+        entities = db.query(EntityModel).filter(EntityModel.id.in_(entity_ids)).all()
+        found_ids = {entity.id for entity in entities}
+
+        # Check that all requested entities were found
+        missing_ids = set(entity_ids) - found_ids
+        if missing_ids:
+            raise ValueError(f"Entities not found: {missing_ids}")
+
+        # Prepare FTS data for all entities
+        fts_data = []
+        vec_metadata_list = []
+
+        for entity in entities:
+            # Prepare FTS data
+            tags, fts_metadata = prepare_fts_data(entity)
+            fts_data.append((entity.id, entity.filepath, tags, fts_metadata))
+
+            # Prepare vector data
+            vec_metadata = prepare_vec_data(entity)
+            vec_metadata_list.append(vec_metadata)
+
+        # Batch update FTS table
+        for entity_id, filepath, tags, metadata in fts_data:
+            db.execute(
+                text(
+                    """
+                    INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
+                    VALUES(:id, :filepath, :tags, :metadata)
+                    """
+                ),
+                {
+                    "id": entity_id,
+                    "filepath": filepath,
+                    "tags": tags,
+                    "metadata": metadata,
+                },
+            )
+
+        # Batch get embeddings
+        embeddings = get_embeddings(vec_metadata_list)
+
+        # Batch update vector table
+        if embeddings:
+            for entity, embedding in zip(entities, embeddings):
+                if embedding:  # Check if embedding is not empty
+                    db.execute(
+                        text("DELETE FROM entities_vec WHERE rowid = :id"),
+                        {"id": entity.id},
+                    )
+                    db.execute(
+                        text(
+                            """
+                            INSERT INTO entities_vec (rowid, embedding)
+                            VALUES (:id, :embedding)
+                            """
+                        ),
+                        {
+                            "id": entity.id,
+                            "embedding": serialize_float32(embedding),
+                        },
+                    )
+
+        db.commit()
+    except Exception as e:
+        logger.error(f"Error batch updating indexes: {e}")
+        db.rollback()
+        raise
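To make the OCR handling above concrete, here is the `process_ocr_result` logic run against a sample payload of the shape it checks for (the sample data is invented; note that `max_length` caps the number of OCR items kept, not the number of characters):

```python
import json


def process_ocr_result(value, max_length=4096):
    # Same logic as the crud.py helper above.
    try:
        ocr_data = json.loads(value)
        if isinstance(ocr_data, list) and all(
            isinstance(item, dict)
            and "dt_boxes" in item
            and "rec_txt" in item
            and "score" in item
            for item in ocr_data
        ):
            # Keep only the recognized text, dropping boxes and scores.
            return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
        return json.dumps(ocr_data, indent=2)
    except json.JSONDecodeError:
        return value


sample = json.dumps([
    {"dt_boxes": [[0, 0], [10, 0], [10, 5], [0, 5]], "rec_txt": "hello", "score": 0.98},
    {"dt_boxes": [[0, 6], [12, 6], [12, 11], [0, 11]], "rec_txt": "world", "score": 0.95},
])
assert process_ocr_result(sample) == "hello world"
assert process_ocr_result("not json at all") == "not json at all"
```
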
diff --git a/memos/embedding.py b/memos/embedding.py
index 73b5f19..267bb6e 100644
--- a/memos/embedding.py
+++ b/memos/embedding.py
@@ -58,11 +58,11 @@ def generate_embeddings(texts: List[str]) -> List[List[float]]:
     return embeddings.tolist()
 
 
-async def get_embeddings(texts: List[str]) -> List[List[float]]:
+def get_embeddings(texts: List[str]) -> List[List[float]]:
     if settings.embedding.use_local:
         embeddings = generate_embeddings(texts)
     else:
-        embeddings = await get_remote_embeddings(texts)
+        embeddings = get_remote_embeddings(texts)
 
     # Round the embedding values to 5 decimal places
     return [
@@ -71,12 +71,12 @@
     ]
 
 
-async def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
+def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
     payload = {"model": settings.embedding.model, "input": texts}
-    async with httpx.AsyncClient() as client:
+    with httpx.Client() as client:
         try:
-            response = await client.post(settings.embedding.endpoint, json=payload)
+            response = client.post(settings.embedding.endpoint, json=payload)
             response.raise_for_status()
             result = response.json()
             return result["embeddings"]
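With the async wrappers removed, callers invoke the embedding helpers directly. A standalone sketch of the synchronous remote path, assuming an endpoint that returns the same `{"embeddings": [...]}` shape used above (the URL and model name are placeholders):

```python
from typing import List

import httpx


def get_remote_embeddings_sketch(
    endpoint: str, model: str, texts: List[str]
) -> List[List[float]]:
    # Mirrors the httpx.Client usage in the diff above.
    payload = {"model": model, "input": texts}
    with httpx.Client() as client:
        response = client.post(endpoint, json=payload)
        response.raise_for_status()
        return response.json()["embeddings"]


# Hypothetical call:
# vectors = get_remote_embeddings_sketch("http://localhost:11434/api/embed", "my-model", ["hello"])
```
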
diff --git a/memos/models.py b/memos/models.py
index 0205195..149459c 100644
--- a/memos/models.py
+++ b/memos/models.py
@@ -21,11 +21,6 @@ from sqlalchemy import text
 import sqlite_vec
 import sys
 from pathlib import Path
-import json
-from .embedding import get_embeddings
-from sqlite_vec import serialize_float32
-import asyncio
-import threading
 
 
 class Base(DeclarativeBase):
@@ -246,7 +241,13 @@ def recreate_fts_and_vec_tables():
 def init_database():
     """Initialize the database."""
     db_path = get_database_path()
-    engine = create_engine(f"sqlite:///{db_path}")
+    engine = create_engine(
+        f"sqlite:///{db_path}",
+        pool_size=3,
+        max_overflow=0,
+        pool_timeout=30,
+        connect_args={"timeout": 30}
+    )
 
     # Use a single event listener for both extension loading and WAL mode setting
     event.listen(engine, "connect", load_extension)
@@ -339,142 +340,3 @@ def init_default_libraries(session, default_plugins):
         session.add(library_plugin)
 
     session.commit()
-
-
-async def update_or_insert_entities_vec(session, target_id, embedding):
-    try:
-        session.execute(
-            text("DELETE FROM entities_vec WHERE rowid = :id"),
-            {"id": target_id}
-        )
-
-        session.execute(
-            text(
-                """
-                INSERT INTO entities_vec (rowid, embedding)
-                VALUES (:id, :embedding)
-                """
-            ),
-            {
-                "id": target_id,
-                "embedding": serialize_float32(embedding),
-            },
-        )
-        session.commit()
-    except Exception as e:
-        print(f"Error updating entities_vec: {e}")
-        session.rollback()
-
-
-def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
-    try:
-        session.execute(
-            text(
-                """
-                INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
-                VALUES(:id, :filepath, :tags, :metadata)
-                """
-            ),
-            {
-                "id": target_id,
-                "filepath": filepath,
-                "tags": tags,
-                "metadata": metadata,
-            },
-        )
-        session.commit()
-    except Exception as e:
-        print(f"Error updating entities_fts: {e}")
-        session.rollback()
-
-
-async def update_fts_and_vec(mapper, connection, entity: EntityModel):
-    session = Session(bind=connection)
-
-    # Prepare FTS data
-    tags = ", ".join([tag.name for tag in entity.tags])
-
-    # Process metadata entries
-    def process_ocr_result(value, max_length=4096):
-        try:
-            ocr_data = json.loads(value)
-            if isinstance(ocr_data, list) and all(
-                isinstance(item, dict)
-                and "dt_boxes" in item
-                and "rec_txt" in item
-                and "score" in item
-                for item in ocr_data
-            ):
-                return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
-            else:
-                return json.dumps(ocr_data, indent=2)
-        except json.JSONDecodeError:
-            return value
-
-    fts_metadata = "\n".join(
-        [
-            f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
-            for entry in entity.metadata_entries
-        ]
-    )
-
-    # Update FTS table
-    update_or_insert_entities_fts(
-        session, entity.id, entity.filepath, tags, fts_metadata
-    )
-
-    # Prepare vector data
-    metadata_text = "\n".join(
-        [
-            f"{entry.key}: {entry.value}"
-            for entry in entity.metadata_entries
-            if entry.key != "ocr_result"
-        ]
-    )
-
-    # Add ocr_result at the end of metadata_text using process_ocr_result
-    ocr_result = next(
-        (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
-        "",
-    )
-    processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
-    metadata_text += f"\nocr_result: {processed_ocr_result}"
-
-    # Use the new get_embeddings function
-    embeddings = await get_embeddings([metadata_text])
-    if not embeddings:
-        embedding = []
-    else:
-        embedding = embeddings[0]
-
-    # Update vector table
-    if embedding:
-        await update_or_insert_entities_vec(session, entity.id, embedding)
-
-
-def delete_fts_and_vec(mapper, connection, entity: EntityModel):
-    connection.execute(
-        text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity.id}
-    )
-    connection.execute(
-        text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
-    )
-
-
-def run_async(coro):
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    return loop.run_until_complete(coro)
-
-
-def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
-    def run_in_thread():
-        run_async(update_fts_and_vec(mapper, connection, entity))
-
-    thread = threading.Thread(target=run_in_thread)
-    thread.start()
-    thread.join()
-
-# Add event listeners for EntityModel
-event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
-event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
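The schema addition below wraps the batch IDs in an object, where the older `index_batch` helper posted a bare list (`json=entity_ids`). Roughly, the wire format becomes:

```python
from typing import List

from pydantic import BaseModel


class BatchIndexRequest(BaseModel):
    entity_ids: List[int]


# Body accepted by POST /entities/batch-index:
req = BatchIndexRequest(entity_ids=[1, 2, 3])
print(req.model_dump_json())  # {"entity_ids":[1,2,3]}  (pydantic v2; use .json() on v1)
```
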
diff --git a/memos/schemas.py b/memos/schemas.py
index d544d0f..9a3be34 100644
--- a/memos/schemas.py
+++ b/memos/schemas.py
@@ -276,3 +276,7 @@ class SearchResult(BaseModel):
 class EntityContext(BaseModel):
     prev: List[Entity]
     next: List[Entity]
+
+
+class BatchIndexRequest(BaseModel):
+    entity_ids: List[int]
diff --git a/memos/server.py b/memos/server.py
index 5fbdd29..79c9047 100644
--- a/memos/server.py
+++ b/memos/server.py
@@ -43,6 +43,7 @@ from .schemas import (
     SearchHit,
     RequestParams,
     EntityContext,
+    BatchIndexRequest,
 )
 from .read_metadata import read_metadata
 from .logging_config import LOGGING_CONFIG
@@ -239,6 +240,7 @@ async def new_entity(
     db: Session = Depends(get_db),
     plugins: Annotated[List[int] | None, Query()] = None,
     trigger_webhooks_flag: bool = True,
+    update_index: bool = False,
 ):
     library = crud.get_library_by_id(library_id, db)
     if library is None:
@@ -249,6 +251,10 @@ async def new_entity(
     entity = crud.create_entity(library_id, new_entity, db)
     if trigger_webhooks_flag:
         await trigger_webhooks(library, entity, request, plugins)
+
+    if update_index:
+        crud.update_entity_index(entity, db)
+
     return entity
 
 
@@ -346,6 +352,7 @@ async def update_entity(
     db: Session = Depends(get_db),
     trigger_webhooks_flag: bool = False,
     plugins: Annotated[List[int] | None, Query()] = None,
+    update_index: bool = False,
 ):
     entity = crud.find_entity_by_id(entity_id, db)
     if entity is None:
@@ -364,6 +371,10 @@ async def update_entity(
             status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
         )
         await trigger_webhooks(library, entity, request, plugins)
+
+    if update_index:
+        crud.update_entity_index(entity, db)
+
     return entity
 
 
@@ -382,6 +393,46 @@ def update_entity_last_scan_at(entity_id: int, db: Session = Depends(get_db)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Entity not found",
         )
+
+
+@app.post(
+    "/entities/{entity_id}/index",
+    status_code=status.HTTP_204_NO_CONTENT,
+    tags=["entity"],
+)
+def update_index(entity_id: int, db: Session = Depends(get_db)):
+    """
+    Update the FTS and vector indexes for an entity.
+    """
+    entity = crud.get_entity_by_id(entity_id, db)
+    if entity is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Entity not found",
+        )
+
+    crud.update_entity_index(entity, db)
+
+
+@app.post(
+    "/entities/batch-index",
+    status_code=status.HTTP_204_NO_CONTENT,
+    tags=["entity"],
+)
+async def batch_update_index(
+    request: BatchIndexRequest,
+    db: Session = Depends(get_db)
+):
+    """
+    Batch update the FTS and vector indexes for multiple entities.
+    """
+    try:
+        crud.batch_update_entity_indices(request.entity_ids, db)
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=str(e)
+        )
 
 
 @app.put("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
@@ -614,12 +665,12 @@ async def search_entities_v2(
     try:
         if q.strip() == "":
             # Use list_entities when q is empty
-            entities = await crud.list_entities(
+            entities = crud.list_entities(
                 db=db, limit=limit, library_ids=library_ids, start=start, end=end
             )
         else:
             # Use hybrid_search when q is not empty
-            entities = await crud.hybrid_search(
+            entities = crud.hybrid_search(
                 query=q,
                 db=db,
                 limit=limit,
diff --git a/memos/test_server.py b/memos/test_server.py
index 4c9c9b3..ca07aca 100644
--- a/memos/test_server.py
+++ b/memos/test_server.py
@@ -85,6 +85,10 @@ def setup_library_with_entity(client):
     assert entity_response.status_code == 200
     entity_id = entity_response.json()["id"]
 
+    # Update the entity's index
+    index_response = client.post(f"/entities/{entity_id}/index")
+    assert index_response.status_code == 204
+
     return library_id, folder_id, entity_id
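Put together, a client exercising the two new endpoints might look like the sketch below; the base URL and entity IDs are placeholders:

```python
import httpx

BASE_URL = "http://localhost:8000"  # placeholder

with httpx.Client() as client:
    # Rebuild the indexes for a single entity.
    resp = client.post(f"{BASE_URL}/entities/42/index")
    print(resp.status_code)  # 204 on success, 404 if the entity is unknown

    # Batch variant; a 404 means at least one ID was missing, per the
    # ValueError -> HTTPException mapping in batch_update_index above.
    resp = client.post(
        f"{BASE_URL}/entities/batch-index",
        json={"entity_ids": [42, 43, 44]},
        timeout=60,
    )
    if resp.status_code == 404:
        print("missing entities:", resp.json()["detail"])
```
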