diff --git a/memos/crud.py b/memos/crud.py index 9ff8d56..19d3fd3 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -596,3 +596,58 @@ async def list_entities( entities = query.order_by(EntityModel.file_created_at.desc()).limit(limit).all() return [Entity(**entity.__dict__) for entity in entities] + + +def get_entity_context( + db: Session, library_id: int, entity_id: int, prev: int = 0, next: int = 0 +) -> Tuple[List[Entity], List[Entity]]: + """ + Get the context (previous and next entities) for a given entity. + Returns a tuple of (previous_entities, next_entities). + """ + # First get the target entity to get its timestamp + target_entity = ( + db.query(EntityModel) + .filter( + EntityModel.id == entity_id, + EntityModel.library_id == library_id, + ) + .first() + ) + + if not target_entity: + return [], [] + + # Get previous entities + prev_entities = [] + if prev > 0: + prev_entities = ( + db.query(EntityModel) + .filter( + EntityModel.library_id == library_id, + EntityModel.file_created_at < target_entity.file_created_at + ) + .order_by(EntityModel.file_created_at.desc()) + .limit(prev) + .all() + ) + # Reverse the list to get chronological order and convert to Entity models + prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1] + + # Get next entities + next_entities = [] + if next > 0: + next_entities = ( + db.query(EntityModel) + .filter( + EntityModel.library_id == library_id, + EntityModel.file_created_at > target_entity.file_created_at + ) + .order_by(EntityModel.file_created_at.asc()) + .limit(next) + .all() + ) + # Convert to Entity models + next_entities = [Entity(**entity.__dict__) for entity in next_entities] + + return prev_entities, next_entities diff --git a/memos/schemas.py b/memos/schemas.py index c406733..401c76e 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -292,3 +292,8 @@ class SearchResult(BaseModel): request_params: RequestParams search_cutoff: bool search_time_ms: int + + +class EntityContext(BaseModel): + prev: List[Entity] + next: List[Entity] diff --git a/memos/server.py b/memos/server.py index a648362..52357d5 100644 --- a/memos/server.py +++ b/memos/server.py @@ -46,6 +46,7 @@ from .schemas import ( SearchResult, SearchHit, RequestParams, + EntityContext, ) from .read_metadata import read_metadata from .logging_config import LOGGING_CONFIG @@ -886,6 +887,54 @@ async def search_entities_v2( ) +@app.get( + "/libraries/{library_id}/entities/{entity_id}/context", + response_model=EntityContext, + tags=["entity"], +) +def get_entity_context( + library_id: int, + entity_id: int, + prev: Annotated[int | None, Query(ge=0, le=100)] = None, + next: Annotated[int | None, Query(ge=0, le=100)] = None, + db: Session = Depends(get_db), +): + """ + Get the context (previous and next entities) for a given entity. + + Args: + library_id: The ID of the library + entity_id: The ID of the target entity + prev: Number of previous entities to fetch (optional) + next: Number of next entities to fetch (optional) + + Returns: + EntityContext object containing prev and next lists of entities + """ + # If both prev and next are None, return empty lists + if prev is None and next is None: + return EntityContext(prev=[], next=[]) + + # Convert None to 0 for the crud function + prev_count = prev if prev is not None else 0 + next_count = next if next is not None else 0 + + # Get the context entities + prev_entities, next_entities = crud.get_entity_context( + db=db, + library_id=library_id, + entity_id=entity_id, + prev=prev_count, + next=next_count + ) + + # Return the context object + return EntityContext( + prev=prev_entities, + next=next_entities + ) + + def run_server(): logging.info("Database path: %s", get_database_path()) if settings.typesense.enabled: