diff --git a/memos/crud.py b/memos/crud.py index 31f512c..f10aeb2 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -102,11 +102,11 @@ def get_entities_of_folder( return [], 0 query = db.query(EntityModel).filter(EntityModel.folder_id == folder_id) - + total_count = query.count() - + entities = query.limit(limit).offset(offset).all() - + return entities, total_count @@ -114,6 +114,10 @@ def get_entity_by_filepath(filepath: str, db: Session) -> Entity | None: return db.query(EntityModel).filter(EntityModel.filepath == filepath).first() +def get_entities_by_filepaths(filepaths: List[str], db: Session) -> List[Entity]: + return db.query(EntityModel).filter(EntityModel.filepath.in_(filepaths)).all() + + def remove_entity(entity_id: int, db: Session): entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first() if entity: diff --git a/memos/server.py b/memos/server.py index 8a6cc68..de26731 100644 --- a/memos/server.py +++ b/memos/server.py @@ -43,8 +43,7 @@ from .logging_config import LOGGING_CONFIG # Configure logging to include datetime logging.basicConfig( - format='%(asctime)s - %(levelname)s - %(message)s', - level=logging.INFO + format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO ) engine = create_engine(f"sqlite:///{get_database_path()}") @@ -238,10 +237,11 @@ def list_entities_in_folder( detail="Folder not found in the specified library", ) - entities, total_count = crud.get_entities_of_folder(library_id, folder_id, db, limit, offset) + entities, total_count = crud.get_entities_of_folder( + library_id, folder_id, db, limit, offset + ) return JSONResponse( - content=jsonable_encoder(entities), - headers={"X-Total-Count": str(total_count)} + content=jsonable_encoder(entities), headers={"X-Total-Count": str(total_count)} ) @@ -261,6 +261,18 @@ def get_entity_by_filepath( return entity +@app.post( + "/libraries/{library_id}/entities/by-filepaths", + response_model=List[Entity], + tags=["entity"], +) +def get_entities_by_filepaths( + library_id: int, filepaths: List[str], db: Session = Depends(get_db) +): + entities = crud.get_entities_by_filepaths(filepaths, db) + return [entity for entity in entities if entity.library_id == library_id] + + @app.get("/entities/{entity_id}", response_model=Entity, tags=["entity"]) def get_entity_by_id(entity_id: int, db: Session = Depends(get_db)): entity = crud.get_entity_by_id(entity_id, db) @@ -539,4 +551,10 @@ def run_server(): print( f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}" ) - uvicorn.run("memos.server:app", host="0.0.0.0", port=8080, reload=False, log_config=LOGGING_CONFIG) + uvicorn.run( + "memos.server:app", + host="0.0.0.0", + port=8080, + reload=False, + log_config=LOGGING_CONFIG, + )