diff --git a/memos/crud.py b/memos/crud.py index 9c63b18..cbfdf6a 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -153,16 +153,22 @@ def get_entity_by_id(entity_id: int, db: Session) -> Entity | None: def update_entity( entity_id: int, updated_entity: UpdateEntityParam, db: Session ) -> Entity: - db_entity = get_entity_by_id(entity_id, db) + db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first() + + if db_entity is None: + raise ValueError(f"Entity with id {entity_id} not found") # Update the main fields of the entity for key, value in updated_entity.model_dump().items(): - if key not in ["tags", "attrs"] and value is not None: + if key not in ["tags", "metadata_entries"] and value is not None: setattr(db_entity, key, value) # Handle tags separately if updated_entity.tags is not None: - db_entity.tags = [] + # Clear existing tags + db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete() + db.commit() + for tag_name in updated_entity.tags: tag = db.query(TagModel).filter(TagModel.name == tag_name).first() if not tag: @@ -179,9 +185,14 @@ def update_entity( db.commit() # Handle attrs separately - if updated_entity.attrs is not None: - db_entity.attrs = [] - for attr in updated_entity.attrs: + if updated_entity.metadata_entries is not None: + # Clear existing attrs + db.query(EntityMetadataModel).filter( + EntityMetadataModel.entity_id == entity_id + ).delete() + db.commit() + + for attr in updated_entity.metadata_entries: entity_metadata = EntityMetadataModel( entity_id=db_entity.id, key=attr.key, @@ -193,7 +204,7 @@ def update_entity( data_type=attr.data_type, ) db.add(entity_metadata) - db_entity.attrs.append(entity_metadata) + db_entity.metadata_entries.append(entity_metadata) db.commit() db.refresh(db_entity) @@ -202,7 +213,13 @@ def update_entity( def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity: db_entity = get_entity_by_id(entity_id, db) - db_entity.tags = [] + if not db_entity: + raise ValueError(f"Entity with id {entity_id} not found") + + # Clear existing tags + db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete() + db.commit() + for tag_name in tags: tag = db.query(TagModel).filter(TagModel.name == tag_name).first() if not tag: @@ -216,6 +233,9 @@ def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity: source=MetadataSource.PLUGIN_GENERATED, ) db.add(entity_tag) + db.commit() + db.refresh(db_entity) + return Entity(**db_entity.__dict__) def update_entity_metadata_entries( diff --git a/memos/schemas.py b/memos/schemas.py index bd5e8c0..dd4ade9 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -13,6 +13,7 @@ class MetadataSource(Enum): class MetadataType(Enum): JSON_DATA = "json" TEXT_DATA = "text" + NUMBER_DATA = "number" class NewLibraryParam(BaseModel): @@ -48,8 +49,8 @@ class UpdateEntityParam(BaseModel): file_last_modified_at: datetime | None = None file_type: str | None = None file_type_group: str | None = None - tags: List[str] = [] - attrs: List[EntityMetadataParam] = [] + tags: List[str] | None = None + metadata_entries: List[EntityMetadataParam] | None = None class UpdateTagParam(BaseModel): diff --git a/memos/test_server.py b/memos/test_server.py index f010f71..f072697 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -485,7 +485,7 @@ def test_add_metadata_entry_to_entity_success(client): source="plugin_generated", data_type=MetadataType.TEXT_DATA, ) - update_entity_param = UpdateEntityParam(attrs=[metadata_entry]) + update_entity_param = UpdateEntityParam(metadata_entries=[metadata_entry]) # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint update_response = client.put( @@ -549,7 +549,7 @@ def test_patch_entity_metadata_entries(client): }, ] update_entity_param = UpdateEntityParam( - attrs=[EntityMetadataParam(**entry) for entry in patch_metadata_entries] + metadata_entries=[EntityMetadataParam(**entry) for entry in patch_metadata_entries] ) # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint