From dd3b32821d919a7f63b708a70af3b9f5cc60edd2 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Sun, 2 Jun 2024 00:11:14 +0800 Subject: [PATCH] feat: list libraries --- memos/crud.py | 7 +++- memos/models.py | 76 ++++++++++++++++++++++++++++----------- memos/schemas.py | 2 +- memos/server.py | 20 +++++++++-- memos/test_server.py | 84 +++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 159 insertions(+), 30 deletions(-) diff --git a/memos/crud.py b/memos/crud.py index a6baa3d..a3e052e 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -1,3 +1,4 @@ +from typing import List from sqlalchemy.orm import Session from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam from .models import LibraryModel, FolderModel, EntityModel, EntityModel, PluginModel, LibraryPluginModel @@ -21,11 +22,15 @@ def create_library(library: NewLibraryParam, db: Session) -> Library: return Library( id=db_library.id, name=db_library.name, - folders=[Folder(id=db_folder.id, name=db_folder.path) for db_folder in db_library.folders], + folders=[Folder(id=db_folder.id, path=db_folder.path) for db_folder in db_library.folders], plugins=[] ) +def get_libraries(db: Session) -> List[Library]: + return db.query(LibraryModel).all() + + def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity: db_entity = EntityModel( **entity.model_dump(), diff --git a/memos/models.py b/memos/models.py index 32e4489..911e378 100644 --- a/memos/models.py +++ b/memos/models.py @@ -6,7 +6,7 @@ from sqlalchemy import ( DateTime, Enum, ForeignKey, - func + func, ) from datetime import datetime from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column @@ -17,23 +17,37 @@ from .schemas import MetadataSource, MetadataType class Base(DeclarativeBase): id: Mapped[int] = mapped_column(Integer, primary_key=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.now(), nullable=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.now(), onupdate=func.now(), nullable=False + ) class LibraryModel(Base): __tablename__ = "libraries" name: Mapped[str] = mapped_column(String, nullable=False) - folders: Mapped[List["FolderModel"]] = relationship("FolderModel", back_populates="library") - plugins: Mapped[List["PluginModel"]] = relationship("LibraryPluginModel", back_populates="library") + folders: Mapped[List["FolderModel"]] = relationship( + "FolderModel", back_populates="library", lazy="joined" + ) + plugins: Mapped[List["PluginModel"]] = relationship( + "LibraryPluginModel", back_populates="library", lazy="joined" + ) class FolderModel(Base): __tablename__ = "folders" path: Mapped[str] = mapped_column(String, nullable=False) - library_id: Mapped[int] = mapped_column(Integer, ForeignKey('libraries.id'), nullable=False) - library: Mapped["LibraryModel"] = relationship("LibraryModel", back_populates="folders") - entities: Mapped[List["EntityModel"]] = relationship("EntityModel", back_populates="folder") + library_id: Mapped[int] = mapped_column( + Integer, ForeignKey("libraries.id"), nullable=False + ) + library: Mapped["LibraryModel"] = relationship( + "LibraryModel", back_populates="folders" + ) + entities: Mapped[List["EntityModel"]] = relationship( + "EntityModel", back_populates="folder" + ) class EntityModel(Base): @@ -45,10 +59,18 @@ class EntityModel(Base): file_last_modified_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) file_type: Mapped[str] = mapped_column(String, nullable=False) last_scan_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - library_id: Mapped[int] = mapped_column(Integer, ForeignKey('libraries.id'), nullable=False) - folder_id: Mapped[int] = mapped_column(Integer, ForeignKey('folders.id'), nullable=False) - folder: Mapped["FolderModel"] = relationship("FolderModel", back_populates="entities") - metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship("EntityMetadataModel") + library_id: Mapped[int] = mapped_column( + Integer, ForeignKey("libraries.id"), nullable=False + ) + folder_id: Mapped[int] = mapped_column( + Integer, ForeignKey("folders.id"), nullable=False + ) + folder: Mapped["FolderModel"] = relationship( + "FolderModel", back_populates="entities" + ) + metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship( + "EntityMetadataModel" + ) tags: Mapped[List["TagModel"]] = relationship("EntityTagModel") @@ -62,17 +84,23 @@ class TagModel(Base): class EntityTagModel(Base): __tablename__ = "entity_tags" - entity_id: Mapped[int] = mapped_column(Integer, ForeignKey('entities.id'), nullable=False) - tag_id: Mapped[int] = mapped_column(Integer, ForeignKey('tags.id'), nullable=False) + entity_id: Mapped[int] = mapped_column( + Integer, ForeignKey("entities.id"), nullable=False + ) + tag_id: Mapped[int] = mapped_column(Integer, ForeignKey("tags.id"), nullable=False) source: Mapped[MetadataSource] = mapped_column(Enum(MetadataSource), nullable=False) class EntityMetadataModel(Base): __tablename__ = "metadata_entries" - entity_id: Mapped[int] = mapped_column(Integer, ForeignKey('entities.id'), nullable=False) + entity_id: Mapped[int] = mapped_column( + Integer, ForeignKey("entities.id"), nullable=False + ) key: Mapped[str] = mapped_column(String, nullable=False) value: Mapped[str] = mapped_column(Text, nullable=False) - source_type: Mapped[MetadataSource] = mapped_column(Enum(MetadataSource), nullable=False) + source_type: Mapped[MetadataSource] = mapped_column( + Enum(MetadataSource), nullable=False + ) source: Mapped[str | None] = mapped_column(String, nullable=True) date_type: Mapped[MetadataType] = mapped_column(Enum(MetadataType), nullable=False) entity = relationship("EntityModel", back_populates="metadata_entries") @@ -88,10 +116,18 @@ class PluginModel(Base): class LibraryPluginModel(Base): __tablename__ = "library_plugins" - library_id: Mapped[int] = mapped_column(Integer, ForeignKey('libraries.id'), nullable=False) - plugin_id: Mapped[int] = mapped_column(Integer, ForeignKey('plugins.id'), nullable=False) - library: Mapped["LibraryModel"] = relationship("LibraryModel", back_populates="plugins") - plugin: Mapped["PluginModel"] = relationship("PluginModel", back_populates="libraries") + library_id: Mapped[int] = mapped_column( + Integer, ForeignKey("libraries.id"), nullable=False + ) + plugin_id: Mapped[int] = mapped_column( + Integer, ForeignKey("plugins.id"), nullable=False + ) + library: Mapped["LibraryModel"] = relationship( + "LibraryModel", back_populates="plugins" + ) + plugin: Mapped["PluginModel"] = relationship( + "PluginModel", back_populates="libraries" + ) # Create the database engine with the path from config diff --git a/memos/schemas.py b/memos/schemas.py index 1c4b9c1..22927a3 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -66,7 +66,7 @@ class NewLibraryPluginParam(BaseModel): class Folder(BaseModel): id: int - name: str + path: str model_config = ConfigDict(from_attributes=True) diff --git a/memos/server.py b/memos/server.py index f3c0881..82144ed 100644 --- a/memos/server.py +++ b/memos/server.py @@ -6,7 +6,14 @@ from sqlalchemy.orm import sessionmaker from typing import List from .config import get_database_path -from .crud import get_library_by_id, create_library, create_entity, create_plugin, add_plugin_to_library +from .crud import ( + get_library_by_id, + create_library, + create_entity, + create_plugin, + add_plugin_to_library, + get_libraries, +) from .schemas import ( Library, Folder, @@ -44,15 +51,22 @@ def new_library(library_param: NewLibraryParam, db: Session = Depends(get_db)): return library +@app.get("/libraries", response_model=List[Library]) +def list_libraries(db: Session = Depends(get_db)): + libraries = get_libraries(db) + return libraries + + @app.post("/libraries/{library_id}/folders", response_model=Folder) def new_folder( library_id: int, - folder: NewFolderParam, db: Session = Depends(get_db), + folder: NewFolderParam, + db: Session = Depends(get_db), ): library = get_library_by_id(library_id, db) if library is None: raise HTTPException(status_code=404, detail="Library not found") - + db_folder = Folder(path=folder.path, library_id=library.id) db.add(db_folder) db.commit() diff --git a/memos/test_server.py b/memos/test_server.py index ca7c4e1..8f7aab0 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -1,11 +1,85 @@ +import pytest + from fastapi.testclient import TestClient - -from .server import app - -client = TestClient(app) +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool +from pathlib import Path -def test_read_main(): +from memos.server import app, get_db +from memos.schemas import Library, NewLibraryParam +from memos.models import Base + + +engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def override_get_db(): + try: + db = TestingSessionLocal() + yield db + finally: + db.close() + + +app.dependency_overrides[get_db] = override_get_db + + +# Setup a fixture for the FastAPI test client +@pytest.fixture +def client(): + Base.metadata.create_all(bind=engine) + with TestClient(app) as client: + yield client + Base.metadata.drop_all(bind=engine) + + +def test_read_main(client): response = client.get("/") assert response.status_code == 200 assert response.json() == {"healthy": True} + + +# Test the new_library endpoint +def test_new_library(client): + library_param = NewLibraryParam(name="Test Library") + # Make a POST request to the /libraries endpoint + response = client.post("/libraries", json=library_param.model_dump()) + # Check that the response is successful + assert response.status_code == 200 + # Check the response data + assert response.json() == { + "id": 1, + "name": "Test Library", + "folders": [], + "plugins": [], + } + + +def test_list_libraries(client): + # Setup data: Create a new library + new_library = NewLibraryParam(name="Sample Library", folders=["/tmp"]) + client.post("/libraries", json=new_library.model_dump(mode="json")) + + # Make a GET request to the /libraries endpoint + response = client.get("/libraries") + + # Check that the response is successful + assert response.status_code == 200 + + # Check the response data + expected_data = [ + { + "id": 1, + "name": "Sample Library", + "folders": [{"id": 1, "path": "/tmp"}], + "plugins": [], + } + ] + assert response.json() == expected_data