feat(server): add search and file api

This commit is contained in:
arkohut 2024-07-02 00:27:19 +08:00
parent 973c5e5006
commit facf05117b
3 changed files with 135 additions and 9 deletions

View File

@ -1,7 +1,13 @@
import json import json
from typing import List from typing import List
from .schemas import MetadataType, EntityMetadata, EntityIndexItem, MetadataIndexItem from .schemas import (
MetadataType,
EntityMetadata,
EntityIndexItem,
MetadataIndexItem,
EntitySearchResult,
)
def convert_metadata_value(metadata: EntityMetadata): def convert_metadata_value(metadata: EntityMetadata):
@ -26,7 +32,9 @@ def upsert(client, entity):
file_last_modified_at=int(entity.file_last_modified_at.timestamp()), file_last_modified_at=int(entity.file_last_modified_at.timestamp()),
file_type=entity.file_type, file_type=entity.file_type,
file_type_group=entity.file_type_group, file_type_group=entity.file_type_group,
last_scan_at=int(entity.last_scan_at.timestamp()) if entity.last_scan_at else None, last_scan_at=(
int(entity.last_scan_at.timestamp()) if entity.last_scan_at else None
),
library_id=entity.library_id, library_id=entity.library_id,
folder_id=entity.folder_id, folder_id=entity.folder_id,
tags=[tag.name for tag in entity.tags], tags=[tag.name for tag in entity.tags],
@ -68,7 +76,9 @@ def remove_entity_by_id(client, entity_id):
) )
def list_all_entities(client, library_id: int, folder_id: int, limit=100, offset=0) -> List[EntityIndexItem]: def list_all_entities(
client, library_id: int, folder_id: int, limit=100, offset=0
) -> List[EntityIndexItem]:
try: try:
response = client.collections["entities"].documents.search( response = client.collections["entities"].documents.search(
{ {
@ -94,15 +104,73 @@ def list_all_entities(client, library_id: int, folder_id: int, limit=100, offset
tags=hit["document"]["tags"], tags=hit["document"]["tags"],
metadata_entries=[ metadata_entries=[
MetadataIndexItem( MetadataIndexItem(
key=entry["key"], key=entry["key"], value=entry["value"], source=entry["source"]
value=entry["value"], )
source=entry["source"] for entry in hit["document"]["metadata_entries"]
) for entry in hit["document"]["metadata_entries"]
], ],
metadata_text=hit["document"]["metadata_text"] metadata_text=hit["document"]["metadata_text"],
) for hit in response["hits"] )
for hit in response["hits"]
] ]
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"Failed to list entities for library {library_id} and folder {folder_id}: {str(e)}", f"Failed to list entities for library {library_id} and folder {folder_id}: {str(e)}",
) )
def search_entities(
    client,
    q: str,
    library_id: int = None,
    folder_id: int = None,
    limit: int = 10,
    offset: int = 0,
) -> List[EntitySearchResult]:
    """Search the Typesense "entities" collection and return typed hits.

    Args:
        client: Typesense client instance.
        q: Query string, matched against tags, metadata_entries, filepath,
            filename and the embedding field.
        library_id: Optional filter restricting hits to one library.
        folder_id: Optional filter restricting hits to one folder.
        limit: Maximum hits per page (Typesense ``per_page``).
        offset: Result offset. Typesense paginates with 1-based pages, so
            this is converted via ``offset // limit + 1``; pass a multiple
            of ``limit`` to avoid skipping or duplicating hits.

    Returns:
        List of EntitySearchResult built from the response hits.

    Raises:
        Exception: wraps any error raised by the underlying search call.
    """
    try:
        # Results are always restricted to images; optional library/folder
        # scoping filters are AND-ed together with that restriction.
        filters = []
        if library_id is not None:
            filters.append(f"library_id:={library_id}")
        if folder_id is not None:
            filters.append(f"folder_id:={folder_id}")
        filters.append("file_type_group:=image")
        search_parameters = {
            "q": q,
            "query_by": "tags,metadata_entries,filepath,filename,embedding",
            "filter_by": " && ".join(filters),
            "per_page": limit,
            # Typesense pages are 1-based.
            "page": offset // limit + 1,
            # Keep the payload small: the embedding vector and the
            # concatenated metadata_text are not needed by callers.
            "exclude_fields": "embedding,metadata_text",
            "sort_by": "_text_match:desc,_vector_distance:asc",
        }
        search_results = client.collections["entities"].documents.search(
            search_parameters
        )
        return [
            EntitySearchResult(
                id=hit["document"]["id"],
                filepath=hit["document"]["filepath"],
                filename=hit["document"]["filename"],
                size=hit["document"]["size"],
                file_created_at=hit["document"]["file_created_at"],
                file_last_modified_at=hit["document"]["file_last_modified_at"],
                file_type=hit["document"]["file_type"],
                file_type_group=hit["document"]["file_type_group"],
                # last_scan_at is optional in the index; default to None.
                last_scan_at=hit["document"].get("last_scan_at"),
                library_id=hit["document"]["library_id"],
                folder_id=hit["document"]["folder_id"],
                tags=hit["document"]["tags"],
                metadata_entries=[
                    MetadataIndexItem(
                        key=entry["key"], value=entry["value"], source=entry["source"]
                    )
                    for entry in hit["document"]["metadata_entries"]
                ],
            )
            for hit in search_results["hits"]
        ]
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise Exception(
            f"Failed to search entities: {str(e)}",
        ) from e

View File

@ -162,3 +162,19 @@ class EntityIndexItem(BaseModel):
tags: List[str] tags: List[str]
metadata_entries: List[MetadataIndexItem] metadata_entries: List[MetadataIndexItem]
metadata_text: str metadata_text: str
class EntitySearchResult(BaseModel):
    """Shape of one search hit returned by the ``/search`` endpoint.

    Appears to mirror ``EntityIndexItem`` minus ``metadata_text`` (the
    search call excludes that field from responses); datetimes are carried
    as Unix-timestamp integers rather than datetime objects.
    """

    # Typesense document id (string, not the integer DB primary key).
    id: str
    filepath: str
    filename: str
    # File size — presumably bytes; confirm against the indexer.
    size: int
    file_created_at: int = Field(..., description="Unix timestamp")
    file_last_modified_at: int = Field(..., description="Unix timestamp")
    file_type: str
    file_type_group: str
    # None when the entity has never been scanned.
    last_scan_at: Optional[int] = Field(None, description="Unix timestamp")
    library_id: int
    folder_id: int
    tags: List[str]
    metadata_entries: List[MetadataIndexItem]

View File

@ -1,10 +1,13 @@
import httpx import httpx
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException, Depends, status, Query, Request from fastapi import FastAPI, HTTPException, Depends, status, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from typing import List, Annotated from typing import List, Annotated
from fastapi.responses import FileResponse
from pathlib import Path
import asyncio import asyncio
import json import json
@ -30,6 +33,7 @@ from .schemas import (
MetadataType, MetadataType,
EntityIndexItem, EntityIndexItem,
MetadataIndexItem, MetadataIndexItem,
EntitySearchResult,
) )
engine = create_engine(f"sqlite:///{get_database_path()}") engine = create_engine(f"sqlite:///{get_database_path()}")
@ -53,6 +57,15 @@ client = typesense.Client(
app = FastAPI() app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Adjust this as needed
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def get_db(): def get_db():
db = SessionLocal() db = SessionLocal()
try: try:
@ -304,6 +317,7 @@ async def remove_entity_from_typesense(entity_id: int, db: Session = Depends(get
) )
return None return None
@app.get( @app.get(
"/libraries/{library_id}/folders/{folder_id}/index", "/libraries/{library_id}/folders/{folder_id}/index",
response_model=List[EntityIndexItem], response_model=List[EntityIndexItem],
@ -331,6 +345,24 @@ def list_entities_in_folder(
return indexing.list_all_entities(client, library_id, folder_id, limit, offset) return indexing.list_all_entities(client, library_id, folder_id, limit, offset)
@app.get("/search", response_model=List[EntitySearchResult], tags=["search"])
async def search_entities(
q: str,
library_id: int = None,
folder_id: int = None,
limit: Annotated[int, Query(ge=1, le=200)] = 10,
offset: int = 0,
db: Session = Depends(get_db),
):
try:
return indexing.search_entities(client, q, library_id, folder_id, limit, offset)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@app.patch("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"]) @app.patch("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
@app.put("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"]) @app.put("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
def patch_entity_tags( def patch_entity_tags(
@ -419,5 +451,15 @@ def add_library_plugin(
crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db) crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db)
@app.get("/files/{file_path:path}", tags=["files"])
async def get_file(file_path: str):
full_path = Path("/") / file_path.strip("/")
# Check if the file exists and is a file
if full_path.is_file():
return FileResponse(full_path)
else:
raise HTTPException(status_code=404, detail="File not found")
def run_server(): def run_server():
uvicorn.run("memos.server:app", host="0.0.0.0", port=8080, reload=True) uvicorn.run("memos.server:app", host="0.0.0.0", port=8080, reload=True)