From facf05117bb62c9d75c13ff3a713e4dde656eabb Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 2 Jul 2024 00:27:19 +0800 Subject: [PATCH] feat(server): add search and file api --- memos/indexing.py | 86 ++++++++++++++++++++++++++++++++++++++++++----- memos/schemas.py | 16 +++++++++ memos/server.py | 42 +++++++++++++++++++++++ 3 files changed, 135 insertions(+), 9 deletions(-) diff --git a/memos/indexing.py b/memos/indexing.py index c068587..8023d8d 100644 --- a/memos/indexing.py +++ b/memos/indexing.py @@ -1,7 +1,13 @@ import json from typing import List -from .schemas import MetadataType, EntityMetadata, EntityIndexItem, MetadataIndexItem +from .schemas import ( + MetadataType, + EntityMetadata, + EntityIndexItem, + MetadataIndexItem, + EntitySearchResult, +) def convert_metadata_value(metadata: EntityMetadata): @@ -26,7 +32,9 @@ def upsert(client, entity): file_last_modified_at=int(entity.file_last_modified_at.timestamp()), file_type=entity.file_type, file_type_group=entity.file_type_group, - last_scan_at=int(entity.last_scan_at.timestamp()) if entity.last_scan_at else None, + last_scan_at=( + int(entity.last_scan_at.timestamp()) if entity.last_scan_at else None + ), library_id=entity.library_id, folder_id=entity.folder_id, tags=[tag.name for tag in entity.tags], @@ -68,7 +76,9 @@ def remove_entity_by_id(client, entity_id): ) -def list_all_entities(client, library_id: int, folder_id: int, limit=100, offset=0) -> List[EntityIndexItem]: +def list_all_entities( + client, library_id: int, folder_id: int, limit=100, offset=0 +) -> List[EntityIndexItem]: try: response = client.collections["entities"].documents.search( { @@ -94,15 +104,73 @@ def list_all_entities(client, library_id: int, folder_id: int, limit=100, offset tags=hit["document"]["tags"], metadata_entries=[ MetadataIndexItem( - key=entry["key"], - value=entry["value"], - source=entry["source"] - ) for entry in hit["document"]["metadata_entries"] + 
def search_entities(
    client,
    q: str,
    library_id: Optional[int] = None,
    folder_id: Optional[int] = None,
    limit: int = 10,
    offset: int = 0,
) -> List[EntitySearchResult]:
    """Search image entities in the Typesense ``entities`` collection.

    Args:
        client: Typesense client instance.
        q: Free-text query, matched against tags, metadata entries,
            filepath, filename and the embedding field.
        library_id: Optional filter restricting hits to one library.
        folder_id: Optional filter restricting hits to one folder.
        limit: Maximum number of hits to return (Typesense ``per_page``).
        offset: Number of hits to skip. NOTE: Typesense paginates by page
            number, so this is only exact when ``offset`` is a multiple
            of ``limit``.

    Returns:
        A list of EntitySearchResult built from the matching documents.

    Raises:
        Exception: wrapping any underlying Typesense/client error.
    """
    # Results are always restricted to images; optional filters are ANDed on.
    filters = ["file_type_group:=image"]
    if library_id is not None:
        filters.append(f"library_id:={library_id}")
    if folder_id is not None:
        filters.append(f"folder_id:={folder_id}")

    search_parameters = {
        "q": q,
        "query_by": "tags,metadata_entries,filepath,filename,embedding",
        "filter_by": " && ".join(filters),
        "per_page": limit,
        "page": offset // limit + 1,
        # The raw embedding vector and the denormalized metadata_text are
        # large and useless to API clients, so strip them from the response.
        "exclude_fields": "embedding,metadata_text",
        "sort_by": "_text_match:desc,_vector_distance:asc",
    }

    try:
        search_results = client.collections["entities"].documents.search(
            search_parameters
        )
        return [
            EntitySearchResult(
                id=hit["document"]["id"],
                filepath=hit["document"]["filepath"],
                filename=hit["document"]["filename"],
                size=hit["document"]["size"],
                file_created_at=hit["document"]["file_created_at"],
                file_last_modified_at=hit["document"]["file_last_modified_at"],
                file_type=hit["document"]["file_type"],
                file_type_group=hit["document"]["file_type_group"],
                # last_scan_at may be absent for never-scanned entities.
                last_scan_at=hit["document"].get("last_scan_at"),
                library_id=hit["document"]["library_id"],
                folder_id=hit["document"]["folder_id"],
                tags=hit["document"]["tags"],
                metadata_entries=[
                    MetadataIndexItem(
                        key=entry["key"],
                        value=entry["value"],
                        source=entry["source"],
                    )
                    for entry in hit["document"]["metadata_entries"]
                ],
            )
            for hit in search_results["hits"]
        ]
    except Exception as e:
        # Chain the original error so the root cause survives the wrap.
        raise Exception(f"Failed to search entities: {str(e)}") from e
@app.get("/search", response_model=List[EntitySearchResult], tags=["search"])
async def search_entities(
    q: str,
    library_id: Optional[int] = None,
    folder_id: Optional[int] = None,
    limit: Annotated[int, Query(ge=1, le=200)] = 10,
    offset: int = 0,
    db: Session = Depends(get_db),
):
    """Full-text / vector search over indexed entities.

    Delegates to ``indexing.search_entities`` against the Typesense client;
    any backend failure is surfaced as an HTTP 500 with the error text.
    ``library_id`` / ``folder_id`` optionally narrow the search scope.
    """
    try:
        return indexing.search_entities(
            client, q, library_id, folder_id, limit, offset
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )


@app.get("/files/{file_path:path}", tags=["files"])
async def get_file(file_path: str):
    """Serve a file from the local filesystem by absolute path.

    SECURITY NOTE(review): this endpoint exposes every file readable by the
    server process — the path is rooted at "/" with no restriction to the
    registered library roots and no authentication. That is acceptable only
    for a strictly local, single-user deployment; otherwise the resolved
    path must be validated against the libraries' folder roots before
    serving.
    """
    # resolve() collapses any ".." components and symlink tricks so the
    # path we hand to FileResponse is the real, normalized target.
    full_path = (Path("/") / file_path.strip("/")).resolve()
    if full_path.is_file():
        return FileResponse(full_path)
    raise HTTPException(status_code=404, detail="File not found")


def run_server():
    """Run the API with uvicorn on all interfaces, port 8080, auto-reload on."""
    uvicorn.run("memos.server:app", host="0.0.0.0", port=8080, reload=True)