fix(entity): entity update issues

This commit is contained in:
arkohut 2024-06-13 11:15:27 +08:00
parent d6a9241ffc
commit 7734b40848
3 changed files with 33 additions and 12 deletions

View File

@ -153,16 +153,22 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None:
def update_entity( def update_entity(
entity_id: int, updated_entity: UpdateEntityParam, db: Session entity_id: int, updated_entity: UpdateEntityParam, db: Session
) -> Entity: ) -> Entity:
db_entity = get_entity_by_id(entity_id, db) db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
if db_entity is None:
raise ValueError(f"Entity with id {entity_id} not found")
# Update the main fields of the entity # Update the main fields of the entity
for key, value in updated_entity.model_dump().items(): for key, value in updated_entity.model_dump().items():
if key not in ["tags", "attrs"] and value is not None: if key not in ["tags", "metadata_entries"] and value is not None:
setattr(db_entity, key, value) setattr(db_entity, key, value)
# Handle tags separately # Handle tags separately
if updated_entity.tags is not None: if updated_entity.tags is not None:
db_entity.tags = [] # Clear existing tags
db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete()
db.commit()
for tag_name in updated_entity.tags: for tag_name in updated_entity.tags:
tag = db.query(TagModel).filter(TagModel.name == tag_name).first() tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
if not tag: if not tag:
@ -179,9 +185,14 @@ def update_entity(
db.commit() db.commit()
# Handle attrs separately # Handle attrs separately
if updated_entity.attrs is not None: if updated_entity.metadata_entries is not None:
db_entity.attrs = [] # Clear existing attrs
for attr in updated_entity.attrs: db.query(EntityMetadataModel).filter(
EntityMetadataModel.entity_id == entity_id
).delete()
db.commit()
for attr in updated_entity.metadata_entries:
entity_metadata = EntityMetadataModel( entity_metadata = EntityMetadataModel(
entity_id=db_entity.id, entity_id=db_entity.id,
key=attr.key, key=attr.key,
@ -193,7 +204,7 @@ def update_entity(
data_type=attr.data_type, data_type=attr.data_type,
) )
db.add(entity_metadata) db.add(entity_metadata)
db_entity.attrs.append(entity_metadata) db_entity.metadata_entries.append(entity_metadata)
db.commit() db.commit()
db.refresh(db_entity) db.refresh(db_entity)
@ -202,7 +213,13 @@ def update_entity(
def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity: def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
db_entity = get_entity_by_id(entity_id, db) db_entity = get_entity_by_id(entity_id, db)
db_entity.tags = [] if not db_entity:
raise ValueError(f"Entity with id {entity_id} not found")
# Clear existing tags
db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete()
db.commit()
for tag_name in tags: for tag_name in tags:
tag = db.query(TagModel).filter(TagModel.name == tag_name).first() tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
if not tag: if not tag:
@ -216,6 +233,9 @@ def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
source=MetadataSource.PLUGIN_GENERATED, source=MetadataSource.PLUGIN_GENERATED,
) )
db.add(entity_tag) db.add(entity_tag)
db.commit()
db.refresh(db_entity)
return Entity(**db_entity.__dict__)
def update_entity_metadata_entries( def update_entity_metadata_entries(

View File

@ -13,6 +13,7 @@ class MetadataSource(Enum):
class MetadataType(Enum): class MetadataType(Enum):
JSON_DATA = "json" JSON_DATA = "json"
TEXT_DATA = "text" TEXT_DATA = "text"
NUMBER_DATA = "number"
class NewLibraryParam(BaseModel): class NewLibraryParam(BaseModel):
@ -48,8 +49,8 @@ class UpdateEntityParam(BaseModel):
file_last_modified_at: datetime | None = None file_last_modified_at: datetime | None = None
file_type: str | None = None file_type: str | None = None
file_type_group: str | None = None file_type_group: str | None = None
tags: List[str] = [] tags: List[str] | None = None
attrs: List[EntityMetadataParam] = [] metadata_entries: List[EntityMetadataParam] | None = None
class UpdateTagParam(BaseModel): class UpdateTagParam(BaseModel):

View File

@ -485,7 +485,7 @@ def test_add_metadata_entry_to_entity_success(client):
source="plugin_generated", source="plugin_generated",
data_type=MetadataType.TEXT_DATA, data_type=MetadataType.TEXT_DATA,
) )
update_entity_param = UpdateEntityParam(attrs=[metadata_entry]) update_entity_param = UpdateEntityParam(metadata_entries=[metadata_entry])
# Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint
update_response = client.put( update_response = client.put(
@ -549,7 +549,7 @@ def test_patch_entity_metadata_entries(client):
}, },
] ]
update_entity_param = UpdateEntityParam( update_entity_param = UpdateEntityParam(
attrs=[EntityMetadataParam(**entry) for entry in patch_metadata_entries] metadata_entries=[EntityMetadataParam(**entry) for entry in patch_metadata_entries]
) )
# Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint