diff --git a/memos/crud.py b/memos/crud.py index 68dd6d2..7b3154f 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -1,6 +1,6 @@ from typing import List from sqlalchemy.orm import Session -from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam, UpdateEntityParam +from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam, UpdateEntityParam, NewFolderParam from .models import LibraryModel, FolderModel, EntityModel, EntityModel, PluginModel, LibraryPluginModel @@ -31,6 +31,14 @@ def get_libraries(db: Session) -> List[Library]: return db.query(LibraryModel).all() +def add_folder(library_id: int, folder: NewFolderParam, db: Session) -> Folder: + db_folder = FolderModel(path=str(folder.path), library_id=library_id) + db.add(db_folder) + db.commit() + db.refresh(db_folder) + return Folder(id=db_folder.id, path=db_folder.path) + + def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity: db_entity = EntityModel( **entity.model_dump(), @@ -46,6 +54,15 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None: return db.query(EntityModel).filter(EntityModel.id == entity_id).first() +def get_entities_of_folder(library_id: int, folder_id: int, db: Session, limit: int = 10, offset: int = 0) -> List[Entity]: + folder = db.query(FolderModel).filter(FolderModel.id == folder_id, FolderModel.library_id == library_id).first() + if folder is None: + return [] + + entities = db.query(EntityModel).filter(EntityModel.folder_id == folder_id).limit(limit).offset(offset).all() + return entities + + def get_entity_by_filepath(filepath: str, db: Session) -> Entity | None: return db.query(EntityModel).filter(EntityModel.filepath == filepath).first() diff --git a/memos/server.py b/memos/server.py index 0785c64..4457af2 100644 --- a/memos/server.py +++ b/memos/server.py @@ -1,9 +1,9 @@ import uvicorn -from fastapi import FastAPI, HTTPException, Depends, status +from fastapi import FastAPI, HTTPException, Depends, status, Query from sqlalchemy.orm import Session from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from typing import List +from typing import List, Annotated from .config import get_database_path import memos.crud as crud @@ -69,11 +69,7 @@ def new_folder( if library is None: raise HTTPException(status_code=404, detail="Library not found") - db_folder = Folder(path=folder.path, library_id=library.id) - db.add(db_folder) - db.commit() - db.refresh(db_folder) - return db_folder + return crud.add_folder(library_id=library.id, folder=folder, db=db) @app.post("/libraries/{library_id}/entities", response_model=Entity) @@ -88,6 +84,28 @@ def new_entity( return entity +@app.get( + "/libraries/{library_id}/folders/{folder_id}/entities", response_model=List[Entity] +) +def list_entities_in_folder( + library_id: int, + folder_id: int, + limit: Annotated[int, Query(ge=1, le=200)] = 10, + offset: int = 0, + db: Session = Depends(get_db), +): + library = crud.get_library_by_id(library_id, db) + if library is None: + raise HTTPException(status_code=404, detail="Library not found") + + if folder_id not in [folder.id for folder in library.folders]: + raise HTTPException( + status_code=404, detail="Folder not found in the specified library" + ) + + return crud.get_entities_of_folder(library_id, folder_id, db, limit, offset) + + @app.get("/libraries/{library_id}/entities/by-filepath", response_model=Entity) def get_entity_by_filepath( library_id: int, filepath: str, db: Session = Depends(get_db) diff --git a/memos/test_server.py b/memos/test_server.py index 892bdba..f30d242 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -8,7 +8,13 @@ from pathlib import Path from memos.server import app, get_db -from memos.schemas import Library, NewLibraryParam, NewEntityParam, UpdateEntityParam +from memos.schemas import ( + Library, + NewLibraryParam, + NewEntityParam, + UpdateEntityParam, + NewFolderParam, +) from memos.models import Base @@ -84,10 +90,13 @@ def test_list_libraries(client): ] assert response.json() == expected_data + def test_new_entity(client): # Setup data: Create a new library new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"]) - library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) library_id = library_response.json()["id"] folder_id = library_response.json()["folders"][0]["id"] @@ -99,9 +108,11 @@ def test_new_entity(client): file_created_at="2023-01-01T00:00:00", file_last_modified_at="2023-01-01T00:00:00", file_type="text/plain", - folder_id=folder_id + folder_id=folder_id, + ) + entity_response = client.post( + f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json") ) - entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")) # Check that the response is successful assert entity_response.status_code == 200 @@ -117,16 +128,19 @@ def test_new_entity(client): assert entity_data["folder_id"] == 1 # Test for library not found - invalid_entity_response = client.post("/libraries/9999/entities", json=new_entity.model_dump(mode="json")) + invalid_entity_response = client.post( + "/libraries/9999/entities", json=new_entity.model_dump(mode="json") + ) assert invalid_entity_response.status_code == 404 assert invalid_entity_response.json() == {"detail": "Library not found"} - def test_update_entity(client): # Setup data: Create a new library and entity new_library = NewLibraryParam(name="Library for Update Test", folders=["/tmp"]) - library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) library_id = library_response.json()["id"] new_entity = NewEntityParam( @@ -136,9 +150,11 @@ def test_update_entity(client): file_created_at="2023-01-01T00:00:00", file_last_modified_at="2023-01-01T00:00:00", file_type="text/plain", - folder_id=1 + folder_id=1, + ) + entity_response = client.post( + f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json") ) - entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")) entity_id = entity_response.json()["id"] # Update the entity @@ -146,9 +162,12 @@ def test_update_entity(client): size=200, file_created_at="2023-01-02T00:00:00", file_last_modified_at="2023-01-02T00:00:00", - file_type="text/markdown" + file_type="text/markdown", + ) + update_response = client.put( + f"/libraries/{library_id}/entities/{entity_id}", + json=updated_entity.model_dump(mode="json"), ) - update_response = client.put(f"/libraries/{library_id}/entities/{entity_id}", json=updated_entity.model_dump(mode="json")) # Check that the response is successful assert update_response.status_code == 200 @@ -162,21 +181,33 @@ def test_update_entity(client): assert updated_data["file_type"] == "text/markdown" # Test for entity not found - invalid_update_response = client.put(f"/libraries/{library_id}/entities/9999", json=updated_entity.model_dump(mode="json")) + invalid_update_response = client.put( + f"/libraries/{library_id}/entities/9999", + json=updated_entity.model_dump(mode="json"), + ) assert invalid_update_response.status_code == 404 - assert invalid_update_response.json() == {"detail": "Entity not found in the specified library"} + assert invalid_update_response.json() == { + "detail": "Entity not found in the specified library" + } # Test for library not found - invalid_update_response = client.put(f"/libraries/9999/entities/{entity_id}", json=updated_entity.model_dump(mode="json")) + invalid_update_response = client.put( + f"/libraries/9999/entities/{entity_id}", + json=updated_entity.model_dump(mode="json"), + ) assert invalid_update_response.status_code == 404 - assert invalid_update_response.json() == {"detail": "Entity not found in the specified library"} + assert invalid_update_response.json() == { + "detail": "Entity not found in the specified library" + } # Test for getting an entity by filepath def test_get_entity_by_filepath(client): # Setup data: Create a new library and entity new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"]) - library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) library_id = library_response.json()["id"] new_entity = NewEntityParam( @@ -186,13 +217,18 @@ def test_get_entity_by_filepath(client): file_created_at="2023-01-01T00:00:00", file_last_modified_at="2023-01-01T00:00:00", file_type="text/plain", - folder_id=1 + folder_id=1, + ) + entity_response = client.post( + f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json") ) - entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")) entity_id = entity_response.json()["id"] - get_response = client.get(f"/libraries/{library_id}/entities/by-filepath", params={"filepath": new_entity.filepath}) - + get_response = client.get( + f"/libraries/{library_id}/entities/by-filepath", + params={"filepath": new_entity.filepath}, + ) + # Check that the response is successful assert get_response.status_code == 200 @@ -205,11 +241,76 @@ def test_get_entity_by_filepath(client): assert entity_data["file_type"] == new_entity.file_type # Test for entity not found - invalid_get_response = client.get(f"/libraries/{library_id}/entities/by-filepath", params={"filepath": "nonexistent.txt"}) + invalid_get_response = client.get( + f"/libraries/{library_id}/entities/by-filepath", + params={"filepath": "nonexistent.txt"}, + ) assert invalid_get_response.status_code == 404 assert invalid_get_response.json() == {"detail": "Entity not found"} # Test for library not found - invalid_get_response = client.get(f"/libraries/9999/entities/by-filepath", params={"filepath": new_entity.filepath}) + invalid_get_response = client.get( + f"/libraries/9999/entities/by-filepath", + params={"filepath": new_entity.filepath}, + ) assert invalid_get_response.status_code == 404 assert invalid_get_response.json() == {"detail": "Entity not found"} + + +def test_list_entities_in_folder(client): + # Setup data: Create a new library and folder + new_library = NewLibraryParam( + name="Library for List Entities Test", folders=["/tmp"] + ) + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) + library_id = library_response.json()["id"] + + new_folder = NewFolderParam(path="/tmp") + folder_response = client.post( + f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") + ) + folder_id = folder_response.json()["id"] + + # Create a new entity in the folder + new_entity = NewEntityParam( + filename="test_list.txt", + filepath="test_list.txt", + size=100, + file_created_at="2023-01-01T00:00:00", + file_last_modified_at="2023-01-01T00:00:00", + file_type="text/plain", + folder_id=folder_id, + ) + entity_response = client.post( + f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json") + ) + entity_id = entity_response.json()["id"] + + # List entities in the folder + list_response = client.get(f"/libraries/{library_id}/folders/{folder_id}/entities") + + # Check that the response is successful + assert list_response.status_code == 200 + + # Check the response data + entities_data = list_response.json() + assert len(entities_data) == 1 + assert entities_data[0]["id"] == entity_id + assert entities_data[0]["filepath"] == new_entity.filepath + assert entities_data[0]["filename"] == new_entity.filename + assert entities_data[0]["size"] == new_entity.size + assert entities_data[0]["file_type"] == new_entity.file_type + + # Test for folder not found + invalid_list_response = client.get(f"/libraries/{library_id}/folders/9999/entities") + assert invalid_list_response.status_code == 404 + assert invalid_list_response.json() == { + "detail": "Folder not found in the specified library" + } + + # Test for library not found + invalid_list_response = client.get(f"/libraries/9999/folders/{folder_id}/entities") + assert invalid_list_response.status_code == 404 + assert invalid_list_response.json() == {"detail": "Library not found"}