From 7512dacd52fe1499755cfbde51d4a817c3e73d90 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Sun, 2 Jun 2024 17:30:56 +0800 Subject: [PATCH] feat(entity): create and update entity by id --- memos/crud.py | 17 ++++++++- memos/schemas.py | 7 ++++ memos/server.py | 43 ++++++++++++++------- memos/test_server.py | 89 +++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 139 insertions(+), 17 deletions(-) diff --git a/memos/crud.py b/memos/crud.py index a3e052e..afc2a72 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -1,11 +1,11 @@ from typing import List from sqlalchemy.orm import Session -from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam +from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam, UpdateEntityParam from .models import LibraryModel, FolderModel, EntityModel, EntityModel, PluginModel, LibraryPluginModel def get_library_by_id(library_id: int, db: Session) -> Library | None: - return db.query(Library).filter(Library.id == library_id).first() + return db.query(LibraryModel).filter(LibraryModel.id == library_id).first() def create_library(library: NewLibraryParam, db: Session) -> Library: @@ -55,3 +55,16 @@ def add_plugin_to_library(library_id: int, plugin_id: int, db: Session): db.add(library_plugin) db.commit() db.refresh(library_plugin) + + +def get_entity_by_id(entity_id: int, db: Session) -> Entity | None: + return db.query(EntityModel).filter(EntityModel.id == entity_id).first() + + +def update_entity(entity_id: int, updated_entity: UpdateEntityParam, db: Session) -> Entity: + db_entity = get_entity_by_id(entity_id, db) + for key, value in updated_entity.model_dump().items(): + setattr(db_entity, key, value) + db.commit() + db.refresh(db_entity) + return db_entity \ No newline at end of file diff --git a/memos/schemas.py b/memos/schemas.py index 53ed27b..4ca143b 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -34,6 +34,13 @@ class NewEntityParam(BaseModel): folder_id: int +class UpdateEntityParam(BaseModel): + size: int + file_created_at: datetime + file_last_modified_at: datetime + file_type: str + + class UpdateTagParam(BaseModel): description: str | None color: str | None diff --git a/memos/server.py b/memos/server.py index 82144ed..d55fb96 100644 --- a/memos/server.py +++ b/memos/server.py @@ -6,14 +6,7 @@ from sqlalchemy.orm import sessionmaker from typing import List from .config import get_database_path -from .crud import ( - get_library_by_id, - create_library, - create_entity, - create_plugin, - add_plugin_to_library, - get_libraries, -) +import memos.crud as crud from .schemas import ( Library, Folder, @@ -22,6 +15,7 @@ from .schemas import ( NewLibraryParam, NewFolderParam, NewEntityParam, + UpdateEntityParam, NewPluginParam, NewLibraryPluginParam, ) @@ -47,13 +41,13 @@ def root(): @app.post("/libraries", response_model=Library) def new_library(library_param: NewLibraryParam, db: Session = Depends(get_db)): - library = create_library(library_param, db) + library = crud.create_library(library_param, db) return library @app.get("/libraries", response_model=List[Library]) def list_libraries(db: Session = Depends(get_db)): - libraries = get_libraries(db) + libraries = crud.get_libraries(db) return libraries @@ -63,7 +57,7 @@ def new_folder( folder: NewFolderParam, db: Session = Depends(get_db), ): - library = get_library_by_id(library_id, db) + library = crud.get_library_by_id(library_id, db) if library is None: raise HTTPException(status_code=404, detail="Library not found") @@ -78,13 +72,34 @@ def new_folder( def new_entity( new_entity: NewEntityParam, library_id: int, db: Session = Depends(get_db) ): - entity = create_entity(library_id, new_entity, db) + library = crud.get_library_by_id(library_id, db) + if library is None: + raise HTTPException(status_code=404, detail="Library not found") + + entity = crud.create_entity(library_id, new_entity, db) + return entity + + +@app.put("/libraries/{library_id}/entities/{entity_id}", response_model=Entity) +def update_entity( + library_id: int, + entity_id: int, + updated_entity: UpdateEntityParam, + db: Session = Depends(get_db), +): + entity = crud.get_entity_by_id(entity_id, db) + if entity is None or entity.library_id != library_id: + raise HTTPException( + status_code=404, detail="Entity not found in the specified library" + ) + + entity = crud.update_entity(entity_id, updated_entity, db) return entity @app.post("/plugins", response_model=Plugin) def new_plugin(new_plugin: NewPluginParam, db: Session = Depends(get_db)): - plugin = create_plugin(new_plugin, db) + plugin = crud.create_plugin(new_plugin, db) return plugin @@ -92,7 +107,7 @@ def new_plugin(new_plugin: NewPluginParam, db: Session = Depends(get_db)): def add_library_plugin( library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db) ): - add_plugin_to_library(library_id, new_plugin.plugin_id, db) + crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db) def run_server(): diff --git a/memos/test_server.py b/memos/test_server.py index 8f7aab0..b7e821e 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -8,7 +8,7 @@ from pathlib import Path from memos.server import app, get_db -from memos.schemas import Library, NewLibraryParam +from memos.schemas import Library, NewLibraryParam, NewEntityParam, UpdateEntityParam from memos.models import Base @@ -83,3 +83,90 @@ def test_list_libraries(client): } ] assert response.json() == expected_data + +def test_new_entity(client): + # Setup data: Create a new library + new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"]) + library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_id = library_response.json()["id"] + folder_id = library_response.json()["folders"][0]["id"] + + # Create a new entity + new_entity = NewEntityParam( + filename="test_entity.txt", + filepath="/tmp/test_entity.txt", + size=150, + file_created_at="2023-01-01T00:00:00", + file_last_modified_at="2023-01-01T00:00:00", + file_type="text/plain", + folder_id=folder_id + ) + entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")) + + # Check that the response is successful + assert entity_response.status_code == 200 + + # Check the response data + entity_data = entity_response.json() + assert entity_data["filename"] == "test_entity.txt" + assert entity_data["filepath"] == "/tmp/test_entity.txt" + assert entity_data["size"] == 150 + assert entity_data["file_created_at"] == "2023-01-01T00:00:00" + assert entity_data["file_last_modified_at"] == "2023-01-01T00:00:00" + assert entity_data["file_type"] == "text/plain" + assert entity_data["folder_id"] == 1 + + # Test for library not found + invalid_entity_response = client.post("/libraries/9999/entities", json=new_entity.model_dump(mode="json")) + assert invalid_entity_response.status_code == 404 + assert invalid_entity_response.json() == {"detail": "Library not found"} + + + +def test_update_entity(client): + # Setup data: Create a new library and entity + new_library = NewLibraryParam(name="Library for Update Test", folders=["/tmp"]) + library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_id = library_response.json()["id"] + + new_entity = NewEntityParam( + filename="test.txt", + filepath="/tmp/test.txt", + size=100, + file_created_at="2023-01-01T00:00:00", + file_last_modified_at="2023-01-01T00:00:00", + file_type="text/plain", + folder_id=1 + ) + entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")) + entity_id = entity_response.json()["id"] + + # Update the entity + updated_entity = UpdateEntityParam( + size=200, + file_created_at="2023-01-02T00:00:00", + file_last_modified_at="2023-01-02T00:00:00", + file_type="text/markdown" + ) + update_response = client.put(f"/libraries/{library_id}/entities/{entity_id}", json=updated_entity.model_dump(mode="json")) + + # Check that the response is successful + assert update_response.status_code == 200 + + # Check the response data + updated_data = update_response.json() + assert updated_data["id"] == entity_id + assert updated_data["size"] == 200 + assert updated_data["file_created_at"] == "2023-01-02T00:00:00" + assert updated_data["file_last_modified_at"] == "2023-01-02T00:00:00" + assert updated_data["file_type"] == "text/markdown" + + # Test for entity not found + invalid_update_response = client.put(f"/libraries/{library_id}/entities/9999", json=updated_entity.model_dump(mode="json")) + assert invalid_update_response.status_code == 404 + assert invalid_update_response.json() == {"detail": "Entity not found in the specified library"} + + # Test for library not found + invalid_update_response = client.put(f"/libraries/9999/entities/{entity_id}", json=updated_entity.model_dump(mode="json")) + assert invalid_update_response.status_code == 404 + assert invalid_update_response.json() == {"detail": "Entity not found in the specified library"}