mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-07 03:35:24 +00:00
feat(server): add api to fetch entity context entities
This commit is contained in:
parent
2308f1e0e8
commit
1b27697b88
@ -596,3 +596,58 @@ async def list_entities(
|
|||||||
entities = query.order_by(EntityModel.file_created_at.desc()).limit(limit).all()
|
entities = query.order_by(EntityModel.file_created_at.desc()).limit(limit).all()
|
||||||
|
|
||||||
return [Entity(**entity.__dict__) for entity in entities]
|
return [Entity(**entity.__dict__) for entity in entities]
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_context(
|
||||||
|
db: Session, library_id: int, entity_id: int, prev: int = 0, next: int = 0
|
||||||
|
) -> Tuple[List[Entity], List[Entity]]:
|
||||||
|
"""
|
||||||
|
Get the context (previous and next entities) for a given entity.
|
||||||
|
Returns a tuple of (previous_entities, next_entities).
|
||||||
|
"""
|
||||||
|
# First get the target entity to get its timestamp
|
||||||
|
target_entity = (
|
||||||
|
db.query(EntityModel)
|
||||||
|
.filter(
|
||||||
|
EntityModel.id == entity_id,
|
||||||
|
EntityModel.library_id == library_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not target_entity:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Get previous entities
|
||||||
|
prev_entities = []
|
||||||
|
if prev > 0:
|
||||||
|
prev_entities = (
|
||||||
|
db.query(EntityModel)
|
||||||
|
.filter(
|
||||||
|
EntityModel.library_id == library_id,
|
||||||
|
EntityModel.file_created_at < target_entity.file_created_at
|
||||||
|
)
|
||||||
|
.order_by(EntityModel.file_created_at.desc())
|
||||||
|
.limit(prev)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# Reverse the list to get chronological order and convert to Entity models
|
||||||
|
prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1]
|
||||||
|
|
||||||
|
# Get next entities
|
||||||
|
next_entities = []
|
||||||
|
if next > 0:
|
||||||
|
next_entities = (
|
||||||
|
db.query(EntityModel)
|
||||||
|
.filter(
|
||||||
|
EntityModel.library_id == library_id,
|
||||||
|
EntityModel.file_created_at > target_entity.file_created_at
|
||||||
|
)
|
||||||
|
.order_by(EntityModel.file_created_at.asc())
|
||||||
|
.limit(next)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# Convert to Entity models
|
||||||
|
next_entities = [Entity(**entity.__dict__) for entity in next_entities]
|
||||||
|
|
||||||
|
return prev_entities, next_entities
|
||||||
|
@ -292,3 +292,8 @@ class SearchResult(BaseModel):
|
|||||||
request_params: RequestParams
|
request_params: RequestParams
|
||||||
search_cutoff: bool
|
search_cutoff: bool
|
||||||
search_time_ms: int
|
search_time_ms: int
|
||||||
|
|
||||||
|
|
||||||
|
class EntityContext(BaseModel):
|
||||||
|
prev: List[Entity]
|
||||||
|
next: List[Entity]
|
||||||
|
@ -46,6 +46,7 @@ from .schemas import (
|
|||||||
SearchResult,
|
SearchResult,
|
||||||
SearchHit,
|
SearchHit,
|
||||||
RequestParams,
|
RequestParams,
|
||||||
|
EntityContext,
|
||||||
)
|
)
|
||||||
from .read_metadata import read_metadata
|
from .read_metadata import read_metadata
|
||||||
from .logging_config import LOGGING_CONFIG
|
from .logging_config import LOGGING_CONFIG
|
||||||
@ -886,6 +887,54 @@ async def search_entities_v2(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(
|
||||||
|
"/libraries/{library_id}/entities/{entity_id}/context",
|
||||||
|
response_model=EntityContext,
|
||||||
|
tags=["entity"],
|
||||||
|
)
|
||||||
|
def get_entity_context(
|
||||||
|
library_id: int,
|
||||||
|
entity_id: int,
|
||||||
|
prev: Annotated[int | None, Query(ge=0, le=100)] = None,
|
||||||
|
next: Annotated[int | None, Query(ge=0, le=100)] = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the context (previous and next entities) for a given entity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
library_id: The ID of the library
|
||||||
|
entity_id: The ID of the target entity
|
||||||
|
prev: Number of previous entities to fetch (optional)
|
||||||
|
next: Number of next entities to fetch (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EntityContext object containing prev and next lists of entities
|
||||||
|
"""
|
||||||
|
# If both prev and next are None, return empty lists
|
||||||
|
if prev is None and next is None:
|
||||||
|
return EntityContext(prev=[], next=[])
|
||||||
|
|
||||||
|
# Convert None to 0 for the crud function
|
||||||
|
prev_count = prev if prev is not None else 0
|
||||||
|
next_count = next if next is not None else 0
|
||||||
|
|
||||||
|
# Get the context entities
|
||||||
|
prev_entities, next_entities = crud.get_entity_context(
|
||||||
|
db=db,
|
||||||
|
library_id=library_id,
|
||||||
|
entity_id=entity_id,
|
||||||
|
prev=prev_count,
|
||||||
|
next=next_count
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the context object
|
||||||
|
return EntityContext(
|
||||||
|
prev=prev_entities,
|
||||||
|
next=next_entities
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
logging.info("Database path: %s", get_database_path())
|
logging.info("Database path: %s", get_database_path())
|
||||||
if settings.typesense.enabled:
|
if settings.typesense.enabled:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user