feat(folder): list entities in folder

This commit is contained in:
arkohut 2024-06-04 16:20:35 +08:00
parent 890245d654
commit 96b219115d
3 changed files with 166 additions and 30 deletions

View File

@ -1,6 +1,6 @@
from typing import List from typing import List
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam, UpdateEntityParam from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam, UpdateEntityParam, NewFolderParam
from .models import LibraryModel, FolderModel, EntityModel, EntityModel, PluginModel, LibraryPluginModel from .models import LibraryModel, FolderModel, EntityModel, EntityModel, PluginModel, LibraryPluginModel
@ -31,6 +31,14 @@ def get_libraries(db: Session) -> List[Library]:
return db.query(LibraryModel).all() return db.query(LibraryModel).all()
def add_folder(library_id: int, folder: NewFolderParam, db: Session) -> Folder:
db_folder = FolderModel(path=str(folder.path), library_id=library_id)
db.add(db_folder)
db.commit()
db.refresh(db_folder)
return Folder(id=db_folder.id, path=db_folder.path)
def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity: def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity:
db_entity = EntityModel( db_entity = EntityModel(
**entity.model_dump(), **entity.model_dump(),
@ -46,6 +54,15 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None:
return db.query(EntityModel).filter(EntityModel.id == entity_id).first() return db.query(EntityModel).filter(EntityModel.id == entity_id).first()
def get_entities_of_folder(library_id: int, folder_id: int, db: Session, limit: int = 10, offset: int = 0) -> List[Entity]:
folder = db.query(FolderModel).filter(FolderModel.id == folder_id, FolderModel.library_id == library_id).first()
if folder is None:
return []
entities = db.query(EntityModel).filter(EntityModel.folder_id == folder_id).limit(limit).offset(offset).all()
return entities
def get_entity_by_filepath(filepath: str, db: Session) -> Entity | None: def get_entity_by_filepath(filepath: str, db: Session) -> Entity | None:
return db.query(EntityModel).filter(EntityModel.filepath == filepath).first() return db.query(EntityModel).filter(EntityModel.filepath == filepath).first()

View File

@ -1,9 +1,9 @@
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException, Depends, status from fastapi import FastAPI, HTTPException, Depends, status, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from typing import List from typing import List, Annotated
from .config import get_database_path from .config import get_database_path
import memos.crud as crud import memos.crud as crud
@ -69,11 +69,7 @@ def new_folder(
if library is None: if library is None:
raise HTTPException(status_code=404, detail="Library not found") raise HTTPException(status_code=404, detail="Library not found")
db_folder = Folder(path=folder.path, library_id=library.id) return crud.add_folder(library_id=library.id, folder=folder, db=db)
db.add(db_folder)
db.commit()
db.refresh(db_folder)
return db_folder
@app.post("/libraries/{library_id}/entities", response_model=Entity) @app.post("/libraries/{library_id}/entities", response_model=Entity)
@ -88,6 +84,28 @@ def new_entity(
return entity return entity
@app.get(
"/libraries/{library_id}/folders/{folder_id}/entities", response_model=List[Entity]
)
def list_entities_in_folder(
library_id: int,
folder_id: int,
limit: Annotated[int, Query(ge=1, le=200)] = 10,
offset: int = 0,
db: Session = Depends(get_db),
):
library = crud.get_library_by_id(library_id, db)
if library is None:
raise HTTPException(status_code=404, detail="Library not found")
if folder_id not in [folder.id for folder in library.folders]:
raise HTTPException(
status_code=404, detail="Folder not found in the specified library"
)
return crud.get_entities_of_folder(library_id, folder_id, db, limit, offset)
@app.get("/libraries/{library_id}/entities/by-filepath", response_model=Entity) @app.get("/libraries/{library_id}/entities/by-filepath", response_model=Entity)
def get_entity_by_filepath( def get_entity_by_filepath(
library_id: int, filepath: str, db: Session = Depends(get_db) library_id: int, filepath: str, db: Session = Depends(get_db)

View File

@ -8,7 +8,13 @@ from pathlib import Path
from memos.server import app, get_db from memos.server import app, get_db
from memos.schemas import Library, NewLibraryParam, NewEntityParam, UpdateEntityParam from memos.schemas import (
Library,
NewLibraryParam,
NewEntityParam,
UpdateEntityParam,
NewFolderParam,
)
from memos.models import Base from memos.models import Base
@ -84,10 +90,13 @@ def test_list_libraries(client):
] ]
assert response.json() == expected_data assert response.json() == expected_data
def test_new_entity(client): def test_new_entity(client):
# Setup data: Create a new library # Setup data: Create a new library
new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"]) new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"])
library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) library_response = client.post(
"/libraries", json=new_library.model_dump(mode="json")
)
library_id = library_response.json()["id"] library_id = library_response.json()["id"]
folder_id = library_response.json()["folders"][0]["id"] folder_id = library_response.json()["folders"][0]["id"]
@ -99,9 +108,11 @@ def test_new_entity(client):
file_created_at="2023-01-01T00:00:00", file_created_at="2023-01-01T00:00:00",
file_last_modified_at="2023-01-01T00:00:00", file_last_modified_at="2023-01-01T00:00:00",
file_type="text/plain", file_type="text/plain",
folder_id=folder_id folder_id=folder_id,
)
entity_response = client.post(
f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")
) )
entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json"))
# Check that the response is successful # Check that the response is successful
assert entity_response.status_code == 200 assert entity_response.status_code == 200
@ -117,16 +128,19 @@ def test_new_entity(client):
assert entity_data["folder_id"] == 1 assert entity_data["folder_id"] == 1
# Test for library not found # Test for library not found
invalid_entity_response = client.post("/libraries/9999/entities", json=new_entity.model_dump(mode="json")) invalid_entity_response = client.post(
"/libraries/9999/entities", json=new_entity.model_dump(mode="json")
)
assert invalid_entity_response.status_code == 404 assert invalid_entity_response.status_code == 404
assert invalid_entity_response.json() == {"detail": "Library not found"} assert invalid_entity_response.json() == {"detail": "Library not found"}
def test_update_entity(client): def test_update_entity(client):
# Setup data: Create a new library and entity # Setup data: Create a new library and entity
new_library = NewLibraryParam(name="Library for Update Test", folders=["/tmp"]) new_library = NewLibraryParam(name="Library for Update Test", folders=["/tmp"])
library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) library_response = client.post(
"/libraries", json=new_library.model_dump(mode="json")
)
library_id = library_response.json()["id"] library_id = library_response.json()["id"]
new_entity = NewEntityParam( new_entity = NewEntityParam(
@ -136,9 +150,11 @@ def test_update_entity(client):
file_created_at="2023-01-01T00:00:00", file_created_at="2023-01-01T00:00:00",
file_last_modified_at="2023-01-01T00:00:00", file_last_modified_at="2023-01-01T00:00:00",
file_type="text/plain", file_type="text/plain",
folder_id=1 folder_id=1,
)
entity_response = client.post(
f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")
) )
entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json"))
entity_id = entity_response.json()["id"] entity_id = entity_response.json()["id"]
# Update the entity # Update the entity
@ -146,9 +162,12 @@ def test_update_entity(client):
size=200, size=200,
file_created_at="2023-01-02T00:00:00", file_created_at="2023-01-02T00:00:00",
file_last_modified_at="2023-01-02T00:00:00", file_last_modified_at="2023-01-02T00:00:00",
file_type="text/markdown" file_type="text/markdown",
)
update_response = client.put(
f"/libraries/{library_id}/entities/{entity_id}",
json=updated_entity.model_dump(mode="json"),
) )
update_response = client.put(f"/libraries/{library_id}/entities/{entity_id}", json=updated_entity.model_dump(mode="json"))
# Check that the response is successful # Check that the response is successful
assert update_response.status_code == 200 assert update_response.status_code == 200
@ -162,21 +181,33 @@ def test_update_entity(client):
assert updated_data["file_type"] == "text/markdown" assert updated_data["file_type"] == "text/markdown"
# Test for entity not found # Test for entity not found
invalid_update_response = client.put(f"/libraries/{library_id}/entities/9999", json=updated_entity.model_dump(mode="json")) invalid_update_response = client.put(
f"/libraries/{library_id}/entities/9999",
json=updated_entity.model_dump(mode="json"),
)
assert invalid_update_response.status_code == 404 assert invalid_update_response.status_code == 404
assert invalid_update_response.json() == {"detail": "Entity not found in the specified library"} assert invalid_update_response.json() == {
"detail": "Entity not found in the specified library"
}
# Test for library not found # Test for library not found
invalid_update_response = client.put(f"/libraries/9999/entities/{entity_id}", json=updated_entity.model_dump(mode="json")) invalid_update_response = client.put(
f"/libraries/9999/entities/{entity_id}",
json=updated_entity.model_dump(mode="json"),
)
assert invalid_update_response.status_code == 404 assert invalid_update_response.status_code == 404
assert invalid_update_response.json() == {"detail": "Entity not found in the specified library"} assert invalid_update_response.json() == {
"detail": "Entity not found in the specified library"
}
# Test for getting an entity by filepath # Test for getting an entity by filepath
def test_get_entity_by_filepath(client): def test_get_entity_by_filepath(client):
# Setup data: Create a new library and entity # Setup data: Create a new library and entity
new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"]) new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"])
library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) library_response = client.post(
"/libraries", json=new_library.model_dump(mode="json")
)
library_id = library_response.json()["id"] library_id = library_response.json()["id"]
new_entity = NewEntityParam( new_entity = NewEntityParam(
@ -186,13 +217,18 @@ def test_get_entity_by_filepath(client):
file_created_at="2023-01-01T00:00:00", file_created_at="2023-01-01T00:00:00",
file_last_modified_at="2023-01-01T00:00:00", file_last_modified_at="2023-01-01T00:00:00",
file_type="text/plain", file_type="text/plain",
folder_id=1 folder_id=1,
)
entity_response = client.post(
f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")
) )
entity_response = client.post(f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json"))
entity_id = entity_response.json()["id"] entity_id = entity_response.json()["id"]
get_response = client.get(f"/libraries/{library_id}/entities/by-filepath", params={"filepath": new_entity.filepath}) get_response = client.get(
f"/libraries/{library_id}/entities/by-filepath",
params={"filepath": new_entity.filepath},
)
# Check that the response is successful # Check that the response is successful
assert get_response.status_code == 200 assert get_response.status_code == 200
@ -205,11 +241,76 @@ def test_get_entity_by_filepath(client):
assert entity_data["file_type"] == new_entity.file_type assert entity_data["file_type"] == new_entity.file_type
# Test for entity not found # Test for entity not found
invalid_get_response = client.get(f"/libraries/{library_id}/entities/by-filepath", params={"filepath": "nonexistent.txt"}) invalid_get_response = client.get(
f"/libraries/{library_id}/entities/by-filepath",
params={"filepath": "nonexistent.txt"},
)
assert invalid_get_response.status_code == 404 assert invalid_get_response.status_code == 404
assert invalid_get_response.json() == {"detail": "Entity not found"} assert invalid_get_response.json() == {"detail": "Entity not found"}
# Test for library not found # Test for library not found
invalid_get_response = client.get(f"/libraries/9999/entities/by-filepath", params={"filepath": new_entity.filepath}) invalid_get_response = client.get(
f"/libraries/9999/entities/by-filepath",
params={"filepath": new_entity.filepath},
)
assert invalid_get_response.status_code == 404 assert invalid_get_response.status_code == 404
assert invalid_get_response.json() == {"detail": "Entity not found"} assert invalid_get_response.json() == {"detail": "Entity not found"}
def test_list_entities_in_folder(client):
# Setup data: Create a new library and folder
new_library = NewLibraryParam(
name="Library for List Entities Test", folders=["/tmp"]
)
library_response = client.post(
"/libraries", json=new_library.model_dump(mode="json")
)
library_id = library_response.json()["id"]
new_folder = NewFolderParam(path="/tmp")
folder_response = client.post(
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
)
folder_id = folder_response.json()["id"]
# Create a new entity in the folder
new_entity = NewEntityParam(
filename="test_list.txt",
filepath="test_list.txt",
size=100,
file_created_at="2023-01-01T00:00:00",
file_last_modified_at="2023-01-01T00:00:00",
file_type="text/plain",
folder_id=folder_id,
)
entity_response = client.post(
f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json")
)
entity_id = entity_response.json()["id"]
# List entities in the folder
list_response = client.get(f"/libraries/{library_id}/folders/{folder_id}/entities")
# Check that the response is successful
assert list_response.status_code == 200
# Check the response data
entities_data = list_response.json()
assert len(entities_data) == 1
assert entities_data[0]["id"] == entity_id
assert entities_data[0]["filepath"] == new_entity.filepath
assert entities_data[0]["filename"] == new_entity.filename
assert entities_data[0]["size"] == new_entity.size
assert entities_data[0]["file_type"] == new_entity.file_type
# Test for folder not found
invalid_list_response = client.get(f"/libraries/{library_id}/folders/9999/entities")
assert invalid_list_response.status_code == 404
assert invalid_list_response.json() == {
"detail": "Folder not found in the specified library"
}
# Test for library not found
invalid_list_response = client.get(f"/libraries/9999/folders/{folder_id}/entities")
assert invalid_list_response.status_code == 404
assert invalid_list_response.json() == {"detail": "Library not found"}