From fe076da948b0b602039b8608272ebeb5dffd14f0 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:10:42 +0800 Subject: [PATCH] feat: prevent duplicated library name --- memos/crud.py | 68 +++++++++++++++++++++++++++++++++++--------- memos/models.py | 2 +- memos/server.py | 5 ++++ memos/test_server.py | 6 ++++ 4 files changed, 66 insertions(+), 15 deletions(-) diff --git a/memos/crud.py b/memos/crud.py index 67f0fb5..7080d37 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -1,7 +1,25 @@ from typing import List from sqlalchemy.orm import Session -from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam, UpdateEntityParam, NewFolderParam -from .models import LibraryModel, FolderModel, EntityModel, EntityModel, PluginModel, LibraryPluginModel +from sqlalchemy import func +from .schemas import ( + Library, + NewLibraryParam, + Folder, + NewEntityParam, + Entity, + Plugin, + NewPluginParam, + UpdateEntityParam, + NewFolderParam, +) +from .models import ( + LibraryModel, + FolderModel, + EntityModel, + EntityModel, + PluginModel, + LibraryPluginModel, +) def get_library_by_id(library_id: int, db: Session) -> Library | None: @@ -22,8 +40,11 @@ def create_library(library: NewLibraryParam, db: Session) -> Library: return Library( id=db_library.id, name=db_library.name, - folders=[Folder(id=db_folder.id, path=db_folder.path) for db_folder in db_library.folders], - plugins=[] + folders=[ + Folder(id=db_folder.id, path=db_folder.path) + for db_folder in db_library.folders + ], + plugins=[], ) @@ -31,6 +52,14 @@ def get_libraries(db: Session) -> List[Library]: return db.query(LibraryModel).all() +def get_library_by_name(library_name: str, db: Session) -> Library | None: + return ( + db.query(LibraryModel) + .filter(func.lower(LibraryModel.name) == library_name.lower()) + .first() + ) + + def add_folder(library_id: int, folder: NewFolderParam, db: Session) -> Folder: db_folder = FolderModel(path=str(folder.path), library_id=library_id) db.add(db_folder) @@ -40,10 +69,7 @@ def add_folder(library_id: int, folder: NewFolderParam, db: Session) -> Folder: def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity: - db_entity = EntityModel( - **entity.model_dump(), - library_id=library_id - ) + db_entity = EntityModel(**entity.model_dump(), library_id=library_id) db.add(db_entity) db.commit() db.refresh(db_entity) @@ -54,12 +80,24 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None: return db.query(EntityModel).filter(EntityModel.id == entity_id).first() -def get_entities_of_folder(library_id: int, folder_id: int, db: Session, limit: int = 10, offset: int = 0) -> List[Entity]: - folder = db.query(FolderModel).filter(FolderModel.id == folder_id, FolderModel.library_id == library_id).first() +def get_entities_of_folder( + library_id: int, folder_id: int, db: Session, limit: int = 10, offset: int = 0 +) -> List[Entity]: + folder = ( + db.query(FolderModel) + .filter(FolderModel.id == folder_id, FolderModel.library_id == library_id) + .first() + ) if folder is None: return [] - entities = db.query(EntityModel).filter(EntityModel.folder_id == folder_id).limit(limit).offset(offset).all() + entities = ( + db.query(EntityModel) + .filter(EntityModel.folder_id == folder_id) + .limit(limit) + .offset(offset) + .all() + ) return entities @@ -77,7 +115,7 @@ def remove_entity(entity_id: int, db: Session): def create_plugin(newPlugin: NewPluginParam, db: Session) -> Plugin: - db_plugin = PluginModel(**newPlugin.model_dump(mode='json')) + db_plugin = PluginModel(**newPlugin.model_dump(mode="json")) db.add(db_plugin) db.commit() db.refresh(db_plugin) @@ -95,10 +133,12 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None: return db.query(EntityModel).filter(EntityModel.id == entity_id).first() -def update_entity(entity_id: int, updated_entity: UpdateEntityParam, db: Session) -> Entity: +def update_entity( + entity_id: int, updated_entity: UpdateEntityParam, db: Session +) -> Entity: db_entity = get_entity_by_id(entity_id, db) for key, value in updated_entity.model_dump().items(): setattr(db_entity, key, value) db.commit() db.refresh(db_entity) - return db_entity \ No newline at end of file + return db_entity diff --git a/memos/models.py b/memos/models.py index 1da0cbd..9838f16 100644 --- a/memos/models.py +++ b/memos/models.py @@ -27,7 +27,7 @@ class Base(DeclarativeBase): class LibraryModel(Base): __tablename__ = "libraries" - name: Mapped[str] = mapped_column(String, nullable=False) + name: Mapped[str] = mapped_column(String, nullable=False, unique=True) folders: Mapped[List["FolderModel"]] = relationship( "FolderModel", back_populates="library", lazy="joined" ) diff --git a/memos/server.py b/memos/server.py index e9f8928..a4ac625 100644 --- a/memos/server.py +++ b/memos/server.py @@ -41,6 +41,11 @@ def root(): @app.post("/libraries", response_model=Library) def new_library(library_param: NewLibraryParam, db: Session = Depends(get_db)): + # Check if a library with the same name (case insensitive) already exists + existing_library = crud.get_library_by_name(library_param.name, db) + if existing_library: + raise HTTPException(status_code=400, detail="Library with this name already exists") + # Remove duplicate folders from the library_param unique_folders = list(set(library_param.folders)) library_param.folders = unique_folders diff --git a/memos/test_server.py b/memos/test_server.py index 4a2bf4e..7e35d2f 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -67,6 +67,12 @@ def test_new_library(client): "plugins": [], } + # Test for duplicate library name + duplicate_response = client.post("/libraries", json=library_param.model_dump()) + # Check that the response indicates a failure due to duplicate name + assert duplicate_response.status_code == 400 + assert duplicate_response.json() == {"detail": "Library with this name already exists"} + def test_list_libraries(client): # Setup data: Create a new library