mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(server): add api to fetch entity context entities
This commit is contained in:
parent
2308f1e0e8
commit
1b27697b88
@ -596,3 +596,58 @@ async def list_entities(
|
||||
entities = query.order_by(EntityModel.file_created_at.desc()).limit(limit).all()
|
||||
|
||||
return [Entity(**entity.__dict__) for entity in entities]
|
||||
|
||||
|
||||
def get_entity_context(
|
||||
db: Session, library_id: int, entity_id: int, prev: int = 0, next: int = 0
|
||||
) -> Tuple[List[Entity], List[Entity]]:
|
||||
"""
|
||||
Get the context (previous and next entities) for a given entity.
|
||||
Returns a tuple of (previous_entities, next_entities).
|
||||
"""
|
||||
# First get the target entity to get its timestamp
|
||||
target_entity = (
|
||||
db.query(EntityModel)
|
||||
.filter(
|
||||
EntityModel.id == entity_id,
|
||||
EntityModel.library_id == library_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not target_entity:
|
||||
return [], []
|
||||
|
||||
# Get previous entities
|
||||
prev_entities = []
|
||||
if prev > 0:
|
||||
prev_entities = (
|
||||
db.query(EntityModel)
|
||||
.filter(
|
||||
EntityModel.library_id == library_id,
|
||||
EntityModel.file_created_at < target_entity.file_created_at
|
||||
)
|
||||
.order_by(EntityModel.file_created_at.desc())
|
||||
.limit(prev)
|
||||
.all()
|
||||
)
|
||||
# Reverse the list to get chronological order and convert to Entity models
|
||||
prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1]
|
||||
|
||||
# Get next entities
|
||||
next_entities = []
|
||||
if next > 0:
|
||||
next_entities = (
|
||||
db.query(EntityModel)
|
||||
.filter(
|
||||
EntityModel.library_id == library_id,
|
||||
EntityModel.file_created_at > target_entity.file_created_at
|
||||
)
|
||||
.order_by(EntityModel.file_created_at.asc())
|
||||
.limit(next)
|
||||
.all()
|
||||
)
|
||||
# Convert to Entity models
|
||||
next_entities = [Entity(**entity.__dict__) for entity in next_entities]
|
||||
|
||||
return prev_entities, next_entities
|
||||
|
@ -292,3 +292,8 @@ class SearchResult(BaseModel):
|
||||
request_params: RequestParams
|
||||
search_cutoff: bool
|
||||
search_time_ms: int
|
||||
|
||||
|
||||
class EntityContext(BaseModel):
|
||||
prev: List[Entity]
|
||||
next: List[Entity]
|
||||
|
@ -46,6 +46,7 @@ from .schemas import (
|
||||
SearchResult,
|
||||
SearchHit,
|
||||
RequestParams,
|
||||
EntityContext,
|
||||
)
|
||||
from .read_metadata import read_metadata
|
||||
from .logging_config import LOGGING_CONFIG
|
||||
@ -886,6 +887,54 @@ async def search_entities_v2(
|
||||
)
|
||||
|
||||
|
||||
@app.get(
|
||||
"/libraries/{library_id}/entities/{entity_id}/context",
|
||||
response_model=EntityContext,
|
||||
tags=["entity"],
|
||||
)
|
||||
def get_entity_context(
|
||||
library_id: int,
|
||||
entity_id: int,
|
||||
prev: Annotated[int | None, Query(ge=0, le=100)] = None,
|
||||
next: Annotated[int | None, Query(ge=0, le=100)] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get the context (previous and next entities) for a given entity.
|
||||
|
||||
Args:
|
||||
library_id: The ID of the library
|
||||
entity_id: The ID of the target entity
|
||||
prev: Number of previous entities to fetch (optional)
|
||||
next: Number of next entities to fetch (optional)
|
||||
|
||||
Returns:
|
||||
EntityContext object containing prev and next lists of entities
|
||||
"""
|
||||
# If both prev and next are None, return empty lists
|
||||
if prev is None and next is None:
|
||||
return EntityContext(prev=[], next=[])
|
||||
|
||||
# Convert None to 0 for the crud function
|
||||
prev_count = prev if prev is not None else 0
|
||||
next_count = next if next is not None else 0
|
||||
|
||||
# Get the context entities
|
||||
prev_entities, next_entities = crud.get_entity_context(
|
||||
db=db,
|
||||
library_id=library_id,
|
||||
entity_id=entity_id,
|
||||
prev=prev_count,
|
||||
next=next_count
|
||||
)
|
||||
|
||||
# Return the context object
|
||||
return EntityContext(
|
||||
prev=prev_entities,
|
||||
next=next_entities
|
||||
)
|
||||
|
||||
|
||||
def run_server():
|
||||
logging.info("Database path: %s", get_database_path())
|
||||
if settings.typesense.enabled:
|
||||
|
Loading…
x
Reference in New Issue
Block a user