feat(indexing): support facet

arkohut 2024-08-21 00:12:31 +08:00
parent 2cf75bee7f
commit 11e447bcbb
3 changed files with 128 additions and 34 deletions

View File

@@ -8,6 +8,12 @@ from .schemas import (
    EntityIndexItem,
    MetadataIndexItem,
    EntitySearchResult,
    SearchResult,
    Facet,
    SearchHit,
    TextMatchInfo,
    HybridSearchInfo,
    RequestParams,
)
from .config import TYPESENSE_COLLECTION_NAME
@@ -133,7 +139,9 @@ def upsert(client, entity):
    # Sync the entity data to Typesense
    try:
        client.collections[TYPESENSE_COLLECTION_NAME].documents.upsert(entity_data.model_dump_json())
        client.collections[TYPESENSE_COLLECTION_NAME].documents.upsert(
            entity_data.model_dump_json()
        )
    except Exception as e:
        raise Exception(
            f"Failed to sync entity to Typesense: {str(e)}",
@@ -199,11 +207,13 @@ def search_entities(
    q: str,
    library_ids: List[int] = None,
    folder_ids: List[int] = None,
    tags: List[str] = None,
    created_dates: List[str] = None,
    limit: int = 48,
    offset: int = 0,
    start: int = None,
    end: int = None,
) -> List[EntitySearchResult]:
) -> SearchResult:
    try:
        filter_by = []
        if library_ids:
@@ -212,6 +222,10 @@ def search_entities(
            filter_by.append(f"folder_id:[{','.join(map(str, folder_ids))}]")
        if start is not None and end is not None:
            filter_by.append(f"file_created_at:={start}..{end}")
        if tags:
            filter_by.append(f"tags:=[{','.join(tags)}]")
        if created_dates:
            filter_by.append(f"created_date:[{','.join(created_dates)}]")
        filter_by_str = " && ".join(filter_by) if filter_by else ""
        search_parameters = {
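A minimal sketch of the filter_by string these new clauses assemble; the tag and date values below are made up for illustration.

tags = ["screenshot", "work"]
created_dates = ["2024-08-20", "2024-08-21"]

filter_by = [
    f"tags:=[{','.join(tags)}]",
    f"created_date:[{','.join(created_dates)}]",
]
filter_by_str = " && ".join(filter_by) if filter_by else ""

print(filter_by_str)
# tags:=[screenshot,work] && created_date:[2024-08-20,2024-08-21]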
@@ -227,36 +241,59 @@ def search_entities(
            "offset": offset,
            "exclude_fields": "metadata_text,embedding",
            "sort_by": "_text_match:desc",
            "facet_by": "created_date,created_month,created_year,tags",
        }
        search_results = client.collections[TYPESENSE_COLLECTION_NAME].documents.search(
            search_parameters
        )
        return [
            EntitySearchResult(
                id=hit["document"]["id"],
                filepath=hit["document"]["filepath"],
                filename=hit["document"]["filename"],
                size=hit["document"]["size"],
                file_created_at=hit["document"]["file_created_at"],
                file_last_modified_at=hit["document"]["file_last_modified_at"],
                file_type=hit["document"]["file_type"],
                file_type_group=hit["document"]["file_type_group"],
                last_scan_at=hit["document"].get("last_scan_at"),
                library_id=hit["document"]["library_id"],
                folder_id=hit["document"]["folder_id"],
                tags=hit["document"]["tags"],
                metadata_entries=[
                    MetadataIndexItem(
                        key=entry["key"], value=entry["value"], source=entry["source"]
                    )
                    for entry in hit["document"]["metadata_entries"]
                ],
                created_date=hit["document"]["created_date"],
                created_month=hit["document"]["created_month"],
                created_year=hit["document"]["created_year"],
        hits = [
            SearchHit(
                document=EntitySearchResult(
                    id=hit["document"]["id"],
                    filepath=hit["document"]["filepath"],
                    filename=hit["document"]["filename"],
                    size=hit["document"]["size"],
                    file_created_at=hit["document"]["file_created_at"],
                    file_last_modified_at=hit["document"]["file_last_modified_at"],
                    file_type=hit["document"]["file_type"],
                    file_type_group=hit["document"]["file_type_group"],
                    last_scan_at=hit["document"].get("last_scan_at"),
                    library_id=hit["document"]["library_id"],
                    folder_id=hit["document"]["folder_id"],
                    tags=hit["document"]["tags"],
                    metadata_entries=[
                        MetadataIndexItem(
                            key=entry["key"],
                            value=entry["value"],
                            source=entry["source"],
                        )
                        for entry in hit["document"]["metadata_entries"]
                    ],
                    created_date=hit["document"].get("created_date"),
                    created_month=hit["document"].get("created_month"),
                    created_year=hit["document"].get("created_year"),
                ),
                highlight=hit.get("highlight", {}),
                highlights=hit.get("highlights", []),
                hybrid_search_info=HybridSearchInfo(**hit["hybrid_search_info"]),
                text_match=hit["text_match"],
                text_match_info=TextMatchInfo(**hit["text_match_info"]),
            )
            for hit in search_results["hits"]
        ]
        return SearchResult(
            facet_counts=[Facet(**facet) for facet in search_results["facet_counts"]],
            found=search_results["found"],
            hits=hits,
            out_of=search_results["out_of"],
            page=search_results["page"],
            request_params=RequestParams(**search_results["request_params"]),
            search_cutoff=search_results["search_cutoff"],
            search_time_ms=search_results["search_time_ms"],
        )
    except Exception as e:
        raise Exception(
            f"Failed to search entities: {str(e)}",
@@ -265,7 +302,9 @@ def search_entities(
def fetch_entity_by_id(client, id: str) -> EntityIndexItem:
    try:
        document = client.collections[TYPESENSE_COLLECTION_NAME].documents[id].retrieve()
        document = (
            client.collections[TYPESENSE_COLLECTION_NAME].documents[id].retrieve()
        )
        return EntitySearchResult(
            id=document["id"],
            filepath=document["filepath"],
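A minimal sketch of the facet-enabled Typesense request that search_entities now issues, assuming a configured typesense.Client; the collection name and query_by field list are illustrative placeholders, since they are not shown in this hunk.

import typesense

# Placeholder connection details for a local Typesense server.
client = typesense.Client(
    {
        "nodes": [{"host": "localhost", "port": "8108", "protocol": "http"}],
        "api_key": "xyz",
        "connection_timeout_seconds": 5,
    }
)

search_parameters = {
    "q": "report",
    "query_by": "filepath,filename,metadata_text",  # illustrative field list
    "filter_by": "tags:=[screenshot]",
    "facet_by": "created_date,created_month,created_year,tags",
    "exclude_fields": "metadata_text,embedding",
    "sort_by": "_text_match:desc",
    "per_page": 48,
}

# "entities" stands in for TYPESENSE_COLLECTION_NAME.
results = client.collections["entities"].documents.search(search_parameters)
print(results["found"], results["facet_counts"])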

View File

@@ -1,5 +1,5 @@
from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field
from typing import List, Optional, Any
from typing import List, Optional, Any, Dict
from datetime import datetime
from enum import Enum
@@ -204,4 +204,56 @@ class EntitySearchResult(BaseModel):
    library_id: int
    folder_id: int
    tags: List[str]
    metadata_entries: List[MetadataIndexItem]
    metadata_entries: List[MetadataIndexItem]
    facets: Optional[Dict[str, Any]] = None

class FacetCount(BaseModel):
    count: int
    highlighted: str
    value: str

class FacetStats(BaseModel):
    total_values: int

class Facet(BaseModel):
    counts: List[FacetCount]
    field_name: str
    sampled: bool
    stats: FacetStats

class TextMatchInfo(BaseModel):
    best_field_score: str
    best_field_weight: int
    fields_matched: int
    num_tokens_dropped: int
    score: str
    tokens_matched: int
    typo_prefix_score: int

class HybridSearchInfo(BaseModel):
    rank_fusion_score: float

class SearchHit(BaseModel):
    document: EntitySearchResult
    highlight: Dict[str, Any] = {}
    highlights: List[Any] = []
    hybrid_search_info: HybridSearchInfo
    text_match: int
    text_match_info: TextMatchInfo

class RequestParams(BaseModel):
    collection_name: str
    first_q: str
    per_page: int
    q: str

class SearchResult(BaseModel):
    facet_counts: List[Facet]
    found: int
    hits: List[SearchHit]
    out_of: int
    page: int
    request_params: RequestParams
    search_cutoff: bool
    search_time_ms: int
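A short sketch of how the new Facet model parses one entry of Typesense's facet_counts, assuming the Facet, FacetCount, and FacetStats models above are importable; the tag value and counts are made up.

# Pydantic coerces the nested dicts into FacetCount and FacetStats instances.
facet = Facet(
    counts=[{"count": 12, "highlighted": "screenshot", "value": "screenshot"}],
    field_name="tags",
    sampled=False,
    stats={"total_values": 1},
)
print(facet.field_name, facet.counts[0].value, facet.counts[0].count)
# tags screenshot 12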

View File

@@ -39,6 +39,7 @@ from .schemas import (
    EntityIndexItem,
    MetadataIndexItem,
    EntitySearchResult,
    SearchResult
)

# Import the logging configuration
@@ -439,24 +440,26 @@ def list_entitiy_indices_in_folder(
    return indexing.list_all_entities(client, library_id, folder_id, limit, offset)


@app.get("/search", response_model=List[EntitySearchResult], tags=["search"])
@app.get("/search", response_model=SearchResult, tags=["search"])
async def search_entities(
    q: str,
    library_ids: str = Query(None, description="Comma-separated list of library IDs"),
    folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
    tags: str = Query(None, description="Comma-separated list of tags"),
    created_dates: str = Query(None, description="Comma-separated list of created dates in YYYY-MM-DD format"),
    limit: Annotated[int, Query(ge=1, le=200)] = 48,
    offset: int = 0,
    start: int = None,
    end: int = None,
    db: Session = Depends(get_db),
):
    library_ids = (
        [int(id) for id in library_ids.split(",") if id] if library_ids else None
    )
    folder_ids = [int(id) for id in folder_ids.split(",") if id] if folder_ids else None
    library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
    folder_ids = [int(id) for id in folder_ids.split(",")] if folder_ids else None
    tags = [tag.strip() for tag in tags.split(",")] if tags else None
    created_dates = [date.strip() for date in created_dates.split(",")] if created_dates else None
    try:
        return indexing.search_entities(
            client, q, library_ids, folder_ids, limit, offset, start, end
            client, q, library_ids, folder_ids, tags, created_dates, limit, offset, start, end
        )
    except Exception as e:
        print(f"Error searching entities: {e}")