feat: prevent duplicated library name

This commit is contained in:
arkohut 2024-06-04 17:10:42 +08:00
parent e36c0a6bec
commit fe076da948
4 changed files with 66 additions and 15 deletions

View File

@ -1,7 +1,25 @@
from typing import List from typing import List
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from .schemas import Library, NewLibraryParam, Folder, NewEntityParam, Entity, Plugin, NewPluginParam, UpdateEntityParam, NewFolderParam from sqlalchemy import func
from .models import LibraryModel, FolderModel, EntityModel, EntityModel, PluginModel, LibraryPluginModel from .schemas import (
Library,
NewLibraryParam,
Folder,
NewEntityParam,
Entity,
Plugin,
NewPluginParam,
UpdateEntityParam,
NewFolderParam,
)
from .models import (
LibraryModel,
FolderModel,
EntityModel,
EntityModel,
PluginModel,
LibraryPluginModel,
)
def get_library_by_id(library_id: int, db: Session) -> Library | None: def get_library_by_id(library_id: int, db: Session) -> Library | None:
@ -22,8 +40,11 @@ def create_library(library: NewLibraryParam, db: Session) -> Library:
return Library( return Library(
id=db_library.id, id=db_library.id,
name=db_library.name, name=db_library.name,
folders=[Folder(id=db_folder.id, path=db_folder.path) for db_folder in db_library.folders], folders=[
plugins=[] Folder(id=db_folder.id, path=db_folder.path)
for db_folder in db_library.folders
],
plugins=[],
) )
@ -31,6 +52,14 @@ def get_libraries(db: Session) -> List[Library]:
return db.query(LibraryModel).all() return db.query(LibraryModel).all()
def get_library_by_name(library_name: str, db: Session) -> Library | None:
return (
db.query(LibraryModel)
.filter(func.lower(LibraryModel.name) == library_name.lower())
.first()
)
def add_folder(library_id: int, folder: NewFolderParam, db: Session) -> Folder: def add_folder(library_id: int, folder: NewFolderParam, db: Session) -> Folder:
db_folder = FolderModel(path=str(folder.path), library_id=library_id) db_folder = FolderModel(path=str(folder.path), library_id=library_id)
db.add(db_folder) db.add(db_folder)
@ -40,10 +69,7 @@ def add_folder(library_id: int, folder: NewFolderParam, db: Session) -> Folder:
def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity: def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity:
db_entity = EntityModel( db_entity = EntityModel(**entity.model_dump(), library_id=library_id)
**entity.model_dump(),
library_id=library_id
)
db.add(db_entity) db.add(db_entity)
db.commit() db.commit()
db.refresh(db_entity) db.refresh(db_entity)
@ -54,12 +80,24 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None:
return db.query(EntityModel).filter(EntityModel.id == entity_id).first() return db.query(EntityModel).filter(EntityModel.id == entity_id).first()
def get_entities_of_folder(library_id: int, folder_id: int, db: Session, limit: int = 10, offset: int = 0) -> List[Entity]: def get_entities_of_folder(
folder = db.query(FolderModel).filter(FolderModel.id == folder_id, FolderModel.library_id == library_id).first() library_id: int, folder_id: int, db: Session, limit: int = 10, offset: int = 0
) -> List[Entity]:
folder = (
db.query(FolderModel)
.filter(FolderModel.id == folder_id, FolderModel.library_id == library_id)
.first()
)
if folder is None: if folder is None:
return [] return []
entities = db.query(EntityModel).filter(EntityModel.folder_id == folder_id).limit(limit).offset(offset).all() entities = (
db.query(EntityModel)
.filter(EntityModel.folder_id == folder_id)
.limit(limit)
.offset(offset)
.all()
)
return entities return entities
@ -77,7 +115,7 @@ def remove_entity(entity_id: int, db: Session):
def create_plugin(newPlugin: NewPluginParam, db: Session) -> Plugin: def create_plugin(newPlugin: NewPluginParam, db: Session) -> Plugin:
db_plugin = PluginModel(**newPlugin.model_dump(mode='json')) db_plugin = PluginModel(**newPlugin.model_dump(mode="json"))
db.add(db_plugin) db.add(db_plugin)
db.commit() db.commit()
db.refresh(db_plugin) db.refresh(db_plugin)
@ -95,10 +133,12 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None:
return db.query(EntityModel).filter(EntityModel.id == entity_id).first() return db.query(EntityModel).filter(EntityModel.id == entity_id).first()
def update_entity(entity_id: int, updated_entity: UpdateEntityParam, db: Session) -> Entity: def update_entity(
entity_id: int, updated_entity: UpdateEntityParam, db: Session
) -> Entity:
db_entity = get_entity_by_id(entity_id, db) db_entity = get_entity_by_id(entity_id, db)
for key, value in updated_entity.model_dump().items(): for key, value in updated_entity.model_dump().items():
setattr(db_entity, key, value) setattr(db_entity, key, value)
db.commit() db.commit()
db.refresh(db_entity) db.refresh(db_entity)
return db_entity return db_entity

View File

@ -27,7 +27,7 @@ class Base(DeclarativeBase):
class LibraryModel(Base): class LibraryModel(Base):
__tablename__ = "libraries" __tablename__ = "libraries"
name: Mapped[str] = mapped_column(String, nullable=False) name: Mapped[str] = mapped_column(String, nullable=False, unique=True)
folders: Mapped[List["FolderModel"]] = relationship( folders: Mapped[List["FolderModel"]] = relationship(
"FolderModel", back_populates="library", lazy="joined" "FolderModel", back_populates="library", lazy="joined"
) )

View File

@ -41,6 +41,11 @@ def root():
@app.post("/libraries", response_model=Library) @app.post("/libraries", response_model=Library)
def new_library(library_param: NewLibraryParam, db: Session = Depends(get_db)): def new_library(library_param: NewLibraryParam, db: Session = Depends(get_db)):
# Check if a library with the same name (case insensitive) already exists
existing_library = crud.get_library_by_name(library_param.name, db)
if existing_library:
raise HTTPException(status_code=400, detail="Library with this name already exists")
# Remove duplicate folders from the library_param # Remove duplicate folders from the library_param
unique_folders = list(set(library_param.folders)) unique_folders = list(set(library_param.folders))
library_param.folders = unique_folders library_param.folders = unique_folders

View File

@ -67,6 +67,12 @@ def test_new_library(client):
"plugins": [], "plugins": [],
} }
# Test for duplicate library name
duplicate_response = client.post("/libraries", json=library_param.model_dump())
# Check that the response indicates a failure due to duplicate name
assert duplicate_response.status_code == 400
assert duplicate_response.json() == {"detail": "Library with this name already exists"}
def test_list_libraries(client): def test_list_libraries(client):
# Setup data: Create a new library # Setup data: Create a new library