From 107f7d06c2ebd80e7d85e718058a35f3fceca574 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:24:17 +0800 Subject: [PATCH] refactor(entity): make entity a root resource --- memos/server.py | 48 +++++++++++++++++++++++++------------------- memos/test_server.py | 28 +++++++++----------------- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/memos/server.py b/memos/server.py index 4a219ff..13ae9c4 100644 --- a/memos/server.py +++ b/memos/server.py @@ -105,7 +105,7 @@ async def trigger_webhooks(library, entity, request): if plugin.webhook_url: location = str( request.url_for( - "get_entity_by_id", library_id=library.id, entity_id=entity.id + "get_entity_by_id", entity_id=entity.id ) ) task = client.post( @@ -182,8 +182,18 @@ def get_entity_by_filepath( return entity +@app.get("/entities/{entity_id}", response_model=Entity) +def get_entity_by_id(entity_id: int, db: Session = Depends(get_db)): + entity = crud.get_entity_by_id(entity_id, db) + if entity is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Entity not found" + ) + return entity + + @app.get("/libraries/{library_id}/entities/{entity_id}", response_model=Entity) -def get_entity_by_id(library_id: int, entity_id: int, db: Session = Depends(get_db)): +def get_entity_by_id_in_library(library_id: int, entity_id: int, db: Session = Depends(get_db)): entity = crud.get_entity_by_id(entity_id, db) if entity is None or entity.library_id != library_id: raise HTTPException( @@ -192,9 +202,8 @@ def get_entity_by_id(library_id: int, entity_id: int, db: Session = Depends(get_ return entity -@app.put("/libraries/{library_id}/entities/{entity_id}", response_model=Entity) +@app.put("/entities/{entity_id}", response_model=Entity) async def update_entity( - library_id: int, entity_id: int, updated_entity: UpdateEntityParam, request: Request, @@ -202,37 +211,35 @@ async def update_entity( trigger_webhooks_flag: bool = False, ): entity = crud.get_entity_by_id(entity_id, db) - if entity is None or entity.library_id != library_id: + if entity is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Entity not found in the specified library", - ) - - library = crud.get_library_by_id(library_id, db) - if library is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Library not found" + detail="Entity not found", ) entity = crud.update_entity(entity_id, updated_entity, db) if trigger_webhooks_flag: + library = crud.get_library_by_id(entity.library_id, db) + if library is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Library not found" + ) await trigger_webhooks(library, entity, request) return entity -@app.patch("/libraries/{library_id}/entities/{entity_id}/tags", response_model=Entity) -@app.put("/libraries/{library_id}/entities/{entity_id}/tags", response_model=Entity) +@app.patch("/entities/{entity_id}/tags", response_model=Entity) +@app.put("/entities/{entity_id}/tags", response_model=Entity) def patch_entity_tags( - library_id: int, entity_id: int, update_tags: UpdateEntityTagsParam, db: Session = Depends(get_db) ): entity = crud.get_entity_by_id(entity_id, db) - if entity is None or entity.library_id != library_id: + if entity is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Entity not found in the specified library", + detail="Entity not found", ) # Use the CRUD function to update the tags @@ -240,18 +247,17 @@ def patch_entity_tags( return entity -@app.patch("/libraries/{library_id}/entities/{entity_id}/metadata", response_model=Entity) +@app.patch("/entities/{entity_id}/metadata", response_model=Entity) def patch_entity_metadata( - library_id: int, entity_id: int, update_metadata: UpdateEntityMetadataParam, db: Session = Depends(get_db) ): entity = crud.get_entity_by_id(entity_id, db) - if entity is None or entity.library_id != library_id: + if entity is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Entity not found in the specified library", + detail="Entity not found", ) # Use the CRUD function to update the metadata entries diff --git a/memos/test_server.py b/memos/test_server.py index a241875..508a9ee 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -199,7 +199,7 @@ def test_update_entity(client): file_type="text/markdown", ) update_response = client.put( - f"/libraries/{library_id}/entities/{entity_id}", + f"/entities/{entity_id}", json=updated_entity.model_dump(mode="json"), ) @@ -216,22 +216,12 @@ def test_update_entity(client): # Test for entity not found invalid_update_response = client.put( - f"/libraries/{library_id}/entities/9999", + f"/entities/9999", json=updated_entity.model_dump(mode="json"), ) assert invalid_update_response.status_code == 404 assert invalid_update_response.json() == { - "detail": "Entity not found in the specified library" - } - - # Test for library not found - invalid_update_response = client.put( - f"/libraries/9999/entities/{entity_id}", - json=updated_entity.model_dump(mode="json"), - ) - assert invalid_update_response.status_code == 404 - assert invalid_update_response.json() == { - "detail": "Entity not found in the specified library" + "detail": "Entity not found" } @@ -461,7 +451,7 @@ def test_update_entity_with_tags(client): # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint update_response = client.put( - f"/libraries/{library_id}/entities/{entity_id}", + f"/entities/{entity_id}", json=update_entity_param.model_dump(mode="json"), ) @@ -490,7 +480,7 @@ def test_add_metadata_entry_to_entity_success(client): # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint update_response = client.put( - f"/libraries/{library_id}/entities/{entity_id}", + f"/entities/{entity_id}", json=update_entity_param.model_dump(mode="json"), ) @@ -516,7 +506,7 @@ def test_update_entity_tags(client): # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint update_response = client.put( - f"/libraries/{library_id}/entities/{entity_id}", + f"/entities/{entity_id}", json=update_entity_param.model_dump(mode="json"), ) @@ -555,7 +545,7 @@ def test_patch_entity_metadata_entries(client): # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint patch_response = client.put( - f"/libraries/{library_id}/entities/{entity_id}", + f"/entities/{entity_id}", json=update_entity_param.model_dump(mode="json"), ) @@ -584,7 +574,7 @@ def test_patch_entity_metadata_entries(client): # Make a PATCH request to the /libraries/{library_id}/entities/{entity_id}/metadata endpoint update_response = client.patch( - f"/libraries/{library_id}/entities/{entity_id}/metadata", + f"/entities/{entity_id}/metadata", json=update_entity_param.model_dump(mode="json"), ) @@ -616,7 +606,7 @@ def test_patch_entity_metadata_entries(client): # Make a PATCH request to the /libraries/{library_id}/entities/{entity_id}/metadata endpoint update_response = client.patch( - f"/libraries/{library_id}/entities/{entity_id}/metadata", + f"/entities/{entity_id}/metadata", json=update_entity_param.model_dump(mode="json"), )