diff --git a/memos/crud.py b/memos/crud.py index fd8b150..9c63b18 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -12,6 +12,7 @@ from .schemas import ( UpdateEntityParam, NewFolderParam, MetadataSource, + EntityMetadataParam, ) from .models import ( LibraryModel, @@ -186,7 +187,9 @@ def update_entity( key=attr.key, value=attr.value, source=attr.source if attr.source is not None else None, - source_type=MetadataSource.PLUGIN_GENERATED if attr.source is not None else None, + source_type=( + MetadataSource.PLUGIN_GENERATED if attr.source is not None else None + ), data_type=attr.data_type, ) db.add(entity_metadata) @@ -195,3 +198,70 @@ def update_entity( db.commit() db.refresh(db_entity) return Entity(**db_entity.__dict__) + + +def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity: + db_entity = get_entity_by_id(entity_id, db) + db_entity.tags = [] + for tag_name in tags: + tag = db.query(TagModel).filter(TagModel.name == tag_name).first() + if not tag: + tag = TagModel(name=tag_name) + db.add(tag) + db.commit() + db.refresh(tag) + entity_tag = EntityTagModel( + entity_id=db_entity.id, + tag_id=tag.id, + source=MetadataSource.PLUGIN_GENERATED, + ) + db.add(entity_tag) + + +def update_entity_metadata_entries( + entity_id: int, updated_metadata: List[EntityMetadataParam], db: Session +) -> Entity: + db_entity = get_entity_by_id(entity_id, db) + + existing_metadata_entries = ( + db.query(EntityMetadataModel) + .filter(EntityMetadataModel.entity_id == db_entity.id) + .all() + ) + + existing_metadata_dict = {entry.key: entry for entry in existing_metadata_entries} + + for metadata in updated_metadata: + if metadata.key in existing_metadata_dict: + existing_metadata = existing_metadata_dict[metadata.key] + existing_metadata.value = metadata.value + existing_metadata.source = ( + metadata.source + if metadata.source is not None + else existing_metadata.source + ) + existing_metadata.source_type = ( + MetadataSource.PLUGIN_GENERATED + if metadata.source is not None + else existing_metadata.source_type + ) + existing_metadata.data_type = metadata.data_type + else: + entity_metadata = EntityMetadataModel( + entity_id=db_entity.id, + key=metadata.key, + value=metadata.value, + source=metadata.source if metadata.source is not None else None, + source_type=( + MetadataSource.PLUGIN_GENERATED + if metadata.source is not None + else None + ), + data_type=metadata.data_type, + ) + db.add(entity_metadata) + db_entity.metadata_entries.append(entity_metadata) + + db.commit() + db.refresh(db_entity) + return Entity(**db_entity.__dict__) diff --git a/memos/server.py b/memos/server.py index 122bfd8..4a219ff 100644 --- a/memos/server.py +++ b/memos/server.py @@ -20,6 +20,8 @@ from .schemas import ( UpdateEntityParam, NewPluginParam, NewLibraryPluginParam, + UpdateEntityTagsParam, + UpdateEntityMetadataParam, ) engine = create_engine(f"sqlite:///{get_database_path()}") @@ -218,6 +220,46 @@ async def update_entity( return entity +@app.patch("/libraries/{library_id}/entities/{entity_id}/tags", response_model=Entity) +@app.put("/libraries/{library_id}/entities/{entity_id}/tags", response_model=Entity) +def patch_entity_tags( + library_id: int, + entity_id: int, + update_tags: UpdateEntityTagsParam, + db: Session = Depends(get_db) +): + entity = crud.get_entity_by_id(entity_id, db) + if entity is None or entity.library_id != library_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Entity not found in the specified library", + ) + + # Use the CRUD function to update the tags + entity = crud.update_entity_tags(entity_id, update_tags.tags, db) + return entity + + +@app.patch("/libraries/{library_id}/entities/{entity_id}/metadata", response_model=Entity) +def patch_entity_metadata( + library_id: int, + entity_id: int, + update_metadata: UpdateEntityMetadataParam, + db: Session = Depends(get_db) +): + entity = crud.get_entity_by_id(entity_id, db) + if entity is None or entity.library_id != library_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Entity not found in the specified library", + ) + + # Use the CRUD function to update the metadata entries + entity = crud.update_entity_metadata_entries(entity_id, update_metadata.metadata_entries, db) + return entity + + + @app.delete( "/libraries/{library_id}/entities/{entity_id}", status_code=status.HTTP_204_NO_CONTENT, diff --git a/memos/test_server.py b/memos/test_server.py index a542165..3d5cc56 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -17,6 +17,7 @@ from memos.schemas import ( NewFolderParam, EntityMetadataParam, MetadataType, + UpdateEntityMetadataParam, ) from memos.models import Base @@ -512,7 +513,7 @@ def test_update_entity_with_tags(client): assert "tag2" in [tag["name"] for tag in updated_entity_data["tags"]] -def test_add_metadata_entry_to_entity_success(client): +def setup_library_with_entity(client): # Create a new library new_library = NewLibraryParam(name="Test Library for Metadata") library_response = client.post( @@ -545,6 +546,13 @@ def test_add_metadata_entry_to_entity_success(client): assert entity_response.status_code == 200 entity_id = entity_response.json()["id"] + return library_id, folder_id, entity_id + + +def test_add_metadata_entry_to_entity_success(client): + + library_id, _, entity_id = setup_library_with_entity(client) + # Add metadata entry to the entity metadata_entry = EntityMetadataParam( key="author", @@ -574,3 +582,144 @@ def test_add_metadata_entry_to_entity_success(client): updated_entity_data["metadata_entries"][0]["data_type"] == MetadataType.ATTRIBUTE.value ) + + +def test_update_entity_tags(client): + library_id, _, entity_id = setup_library_with_entity(client) + + # Add tags to the entity + tags = ["tag1", "tag2", "tag3"] + update_entity_param = UpdateEntityParam(tags=tags) + + # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint + update_response = client.put( + f"/libraries/{library_id}/entities/{entity_id}", + json=update_entity_param.model_dump(mode="json"), + ) + + # Check that the response is successful + assert update_response.status_code == 200 + + # Check the response data + updated_entity_data = update_response.json() + assert "tags" in updated_entity_data + assert sorted([t["name"] for t in updated_entity_data["tags"]]) == sorted( + tags, key=str + ) + + +def test_patch_entity_metadata_entries(client): + library_id, _, entity_id = setup_library_with_entity(client) + + # Patch metadata entries of the entity + patch_metadata_entries = [ + { + "key": "author", + "value": "Jane Smith", + "source": "user_generated", + "data_type": MetadataType.ATTRIBUTE.value, + }, + { + "key": "year", + "value": "2023", + "source": "user_generated", + "data_type": MetadataType.ATTRIBUTE.value, + }, + ] + update_entity_param = UpdateEntityParam( + attrs=[ + EntityMetadataParam(**entry) for entry in patch_metadata_entries + ] + ) + + # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint + patch_response = client.put( + f"/libraries/{library_id}/entities/{entity_id}", + json=update_entity_param.model_dump(mode="json"), + ) + + # Check that the response is successful + assert patch_response.status_code == 200 + + # Check the response data + patched_entity_data = patch_response.json() + assert "metadata_entries" in patched_entity_data + assert len(patched_entity_data["metadata_entries"]) == 2 + assert patched_entity_data["metadata_entries"][0]["key"] == "author" + assert patched_entity_data["metadata_entries"][0]["value"] == "Jane Smith" + assert patched_entity_data["metadata_entries"][0]["source"] == "user_generated" + assert ( + patched_entity_data["metadata_entries"][0]["data_type"] + == MetadataType.ATTRIBUTE.value + ) + assert patched_entity_data["metadata_entries"][1]["key"] == "year" + assert patched_entity_data["metadata_entries"][1]["value"] == "2023" + assert patched_entity_data["metadata_entries"][1]["source"] == "user_generated" + assert ( + patched_entity_data["metadata_entries"][1]["data_type"] + == MetadataType.ATTRIBUTE.value + ) + + # Update the "author" attribute of the entity + updated_metadata_entries = [ + { + "key": "author", + "value": "John Doe", + "source": "user_generated", + "data_type": MetadataType.ATTRIBUTE.value, + } + ] + update_entity_param = UpdateEntityMetadataParam( + metadata_entries=[ + EntityMetadataParam(**entry) for entry in updated_metadata_entries + ] + ) + + # Make a PATCH request to the /libraries/{library_id}/entities/{entity_id}/metadata endpoint + update_response = client.patch( + f"/libraries/{library_id}/entities/{entity_id}/metadata", + json=update_entity_param.model_dump(mode="json"), + ) + + # Check that the response is successful + assert update_response.status_code == 200 + + # Check the response data + updated_entity_data = update_response.json() + assert "metadata_entries" in updated_entity_data + assert any( + entry["key"] == "author" and entry["value"] == "John Doe" + for entry in updated_entity_data["metadata_entries"] + ) + + # Add a new attribute "media_type" with value "book" + new_metadata_entry = { + "key": "media_type", + "value": "book", + "source": "user_generated", + "data_type": MetadataType.ATTRIBUTE.value, + } + updated_metadata_entries.append(new_metadata_entry) + + update_entity_param = UpdateEntityMetadataParam( + metadata_entries=[ + EntityMetadataParam(**entry) for entry in updated_metadata_entries + ] + ) + + # Make a PATCH request to the /libraries/{library_id}/entities/{entity_id}/metadata endpoint + update_response = client.patch( + f"/libraries/{library_id}/entities/{entity_id}/metadata", + json=update_entity_param.model_dump(mode="json"), + ) + + # Check that the response is successful + assert update_response.status_code == 200 + + # Check the response data + updated_entity_data = update_response.json() + assert "metadata_entries" in updated_entity_data + assert any( + entry["key"] == "media_type" and entry["value"] == "book" + for entry in updated_entity_data["metadata_entries"] + )