feat(server): add api to fetch entity context entities

This commit is contained in:
arkohut 2024-11-04 16:19:48 +08:00
parent 2308f1e0e8
commit 1b27697b88
3 changed files with 109 additions and 0 deletions

View File

@ -596,3 +596,58 @@ async def list_entities(
entities = query.order_by(EntityModel.file_created_at.desc()).limit(limit).all()
return [Entity(**entity.__dict__) for entity in entities]
def get_entity_context(
db: Session, library_id: int, entity_id: int, prev: int = 0, next: int = 0
) -> Tuple[List[Entity], List[Entity]]:
"""
Get the context (previous and next entities) for a given entity.
Returns a tuple of (previous_entities, next_entities).
"""
# First get the target entity to get its timestamp
target_entity = (
db.query(EntityModel)
.filter(
EntityModel.id == entity_id,
EntityModel.library_id == library_id,
)
.first()
)
if not target_entity:
return [], []
# Get previous entities
prev_entities = []
if prev > 0:
prev_entities = (
db.query(EntityModel)
.filter(
EntityModel.library_id == library_id,
EntityModel.file_created_at < target_entity.file_created_at
)
.order_by(EntityModel.file_created_at.desc())
.limit(prev)
.all()
)
# Reverse the list to get chronological order and convert to Entity models
prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1]
# Get next entities
next_entities = []
if next > 0:
next_entities = (
db.query(EntityModel)
.filter(
EntityModel.library_id == library_id,
EntityModel.file_created_at > target_entity.file_created_at
)
.order_by(EntityModel.file_created_at.asc())
.limit(next)
.all()
)
# Convert to Entity models
next_entities = [Entity(**entity.__dict__) for entity in next_entities]
return prev_entities, next_entities

View File

@ -292,3 +292,8 @@ class SearchResult(BaseModel):
request_params: RequestParams
search_cutoff: bool
search_time_ms: int
class EntityContext(BaseModel):
prev: List[Entity]
next: List[Entity]

View File

@ -46,6 +46,7 @@ from .schemas import (
SearchResult,
SearchHit,
RequestParams,
EntityContext,
)
from .read_metadata import read_metadata
from .logging_config import LOGGING_CONFIG
@ -886,6 +887,54 @@ async def search_entities_v2(
)
@app.get(
"/libraries/{library_id}/entities/{entity_id}/context",
response_model=EntityContext,
tags=["entity"],
)
def get_entity_context(
library_id: int,
entity_id: int,
prev: Annotated[int | None, Query(ge=0, le=100)] = None,
next: Annotated[int | None, Query(ge=0, le=100)] = None,
db: Session = Depends(get_db),
):
"""
Get the context (previous and next entities) for a given entity.
Args:
library_id: The ID of the library
entity_id: The ID of the target entity
prev: Number of previous entities to fetch (optional)
next: Number of next entities to fetch (optional)
Returns:
EntityContext object containing prev and next lists of entities
"""
# If both prev and next are None, return empty lists
if prev is None and next is None:
return EntityContext(prev=[], next=[])
# Convert None to 0 for the crud function
prev_count = prev if prev is not None else 0
next_count = next if next is not None else 0
# Get the context entities
prev_entities, next_entities = crud.get_entity_context(
db=db,
library_id=library_id,
entity_id=entity_id,
prev=prev_count,
next=next_count
)
# Return the context object
return EntityContext(
prev=prev_entities,
next=next_entities
)
def run_server():
logging.info("Database path: %s", get_database_path())
if settings.typesense.enabled: