diff --git a/memos/crud.py b/memos/crud.py index afc2a72..68dd6d2 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -42,6 +42,14 @@ def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entit return db_entity +def get_entity_by_id(entity_id: int, db: Session) -> Entity | None: + return db.query(EntityModel).filter(EntityModel.id == entity_id).first() + + +def get_entity_by_filepath(filepath: str, db: Session) -> Entity | None: + return db.query(EntityModel).filter(EntityModel.filepath == filepath).first() + + def create_plugin(newPlugin: NewPluginParam, db: Session) -> Plugin: db_plugin = PluginModel(**newPlugin.model_dump(mode='json')) db.add(db_plugin) diff --git a/memos/server.py b/memos/server.py index d55fb96..800b574 100644 --- a/memos/server.py +++ b/memos/server.py @@ -80,6 +80,22 @@ def new_entity( return entity +@app.get("/libraries/{library_id}/entities/{entity_id}", response_model=Entity) +def get_entity_by_id(library_id: int, entity_id: int, db: Session = Depends(get_db)): + entity = crud.get_entity_by_id(entity_id, db) + if entity is None or entity.library_id != library_id: + raise HTTPException(status_code=404, detail="Entity not found") + return entity + + +@app.get("/libraries/{library_id}/entities", response_model=Entity) +def get_entity_by_filepath(library_id: int, filepath: str, db: Session = Depends(get_db)): + entity = crud.get_entity_by_filepath(filepath, db) + if entity is None or entity.library_id != library_id: + raise HTTPException(status_code=404, detail="Entity not found") + return entity + + @app.put("/libraries/{library_id}/entities/{entity_id}", response_model=Entity) def update_entity( library_id: int, diff --git a/memos/test_server.py b/memos/test_server.py index b7e821e..0b0e8f8 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -170,3 +170,47 @@ def test_update_entity(client): invalid_update_response = client.put(f"/libraries/9999/entities/{entity_id}", json=updated_entity.model_dump(mode="json")) assert invalid_update_response.status_code == 404 assert invalid_update_response.json() == {"detail": "Entity not found in the specified library"} + + +# Test for getting an entity by filepath +def test_get_entity_by_filepath(client): + # Setup data: Create a new library and entity + new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"]) + library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_id = library_response.json()["id"] + + new_entity = NewEntityParam( + filename="test_get.txt", + filepath="/tmp/test_get.txt", + size=100, + file_created_at="2023-01-01T00:00:00", + file_last_modified_at="2023-01-01T00:00:00", + file_type="text/plain", + folder_id=1 + ) + entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")) + entity_id = entity_response.json()["id"] + + # Get the entity by filepath + get_response = client.get(f"/libraries/{library_id}/entities", params={"filepath": new_entity.filepath}) + + # Check that the response is successful + assert get_response.status_code == 200 + + # Check the response data + entity_data = get_response.json() + assert entity_data["id"] == entity_id + assert entity_data["filepath"] == new_entity.filepath + assert entity_data["filename"] == new_entity.filename + assert entity_data["size"] == new_entity.size + assert entity_data["file_type"] == new_entity.file_type + + # Test for entity not found + invalid_get_response = client.get(f"/libraries/{library_id}/entities", params={"filepath": "nonexistent.txt"}) + assert invalid_get_response.status_code == 404 + assert invalid_get_response.json() == {"detail": "Entity not found"} + + # Test for library not found + invalid_get_response = client.get(f"/libraries/9999/entities", params={"filepath": new_entity.filepath}) + assert invalid_get_response.status_code == 404 + assert invalid_get_response.json() == {"detail": "Entity not found"}