From 8ec89219f39f5f57aa85a118672782c6be8c410e Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 10 Jun 2024 00:29:27 +0800 Subject: [PATCH] feat(entity): update tags and metadata --- memos/crud.py | 43 +++++++++- memos/models.py | 10 +-- memos/schemas.py | 69 ++++++++-------- memos/server.py | 4 +- memos/test_server.py | 187 ++++++++++++++++++++++++++++++++++++++----- 5 files changed, 253 insertions(+), 60 deletions(-) diff --git a/memos/crud.py b/memos/crud.py index 6fae7f9..fd8b150 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -11,6 +11,7 @@ from .schemas import ( NewPluginParam, UpdateEntityParam, NewFolderParam, + MetadataSource, ) from .models import ( LibraryModel, @@ -19,6 +20,9 @@ from .models import ( EntityModel, PluginModel, LibraryPluginModel, + TagModel, + EntityMetadataModel, + EntityTagModel, ) @@ -149,8 +153,45 @@ def update_entity( entity_id: int, updated_entity: UpdateEntityParam, db: Session ) -> Entity: db_entity = get_entity_by_id(entity_id, db) + + # Update the main fields of the entity for key, value in updated_entity.model_dump().items(): - setattr(db_entity, key, value) + if key not in ["tags", "attrs"] and value is not None: + setattr(db_entity, key, value) + + # Handle tags separately + if updated_entity.tags is not None: + db_entity.tags = [] + for tag_name in updated_entity.tags: + tag = db.query(TagModel).filter(TagModel.name == tag_name).first() + if not tag: + tag = TagModel(name=tag_name) + db.add(tag) + db.commit() + db.refresh(tag) + entity_tag = EntityTagModel( + entity_id=db_entity.id, + tag_id=tag.id, + source=MetadataSource.PLUGIN_GENERATED, + ) + db.add(entity_tag) + db.commit() + + # Handle attrs separately + if updated_entity.attrs is not None: + db_entity.attrs = [] + for attr in updated_entity.attrs: + entity_metadata = EntityMetadataModel( + entity_id=db_entity.id, + key=attr.key, + value=attr.value, + source=attr.source if attr.source is not None else None, + source_type=MetadataSource.PLUGIN_GENERATED if attr.source is not None else None, + data_type=attr.data_type, + ) + db.add(entity_metadata) + db_entity.attrs.append(entity_metadata) + db.commit() db.refresh(db_entity) return Entity(**db_entity.__dict__) diff --git a/memos/models.py b/memos/models.py index 249bfda..c106f0e 100644 --- a/memos/models.py +++ b/memos/models.py @@ -69,9 +69,10 @@ class EntityModel(Base): "FolderModel", back_populates="entities" ) metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship( - "EntityMetadataModel" + "EntityMetadataModel", lazy="joined" ) - tags: Mapped[List["TagModel"]] = relationship("EntityTagModel") + tags: Mapped[List["TagModel"]] = relationship("TagModel", secondary="entity_tags", lazy="joined") + class TagModel(Base): @@ -79,7 +80,7 @@ class TagModel(Base): name: Mapped[str] = mapped_column(String, nullable=False) description: Mapped[str | None] = mapped_column(Text, nullable=True) color: Mapped[str | None] = mapped_column(String, nullable=True) - source: Mapped[str | None] = mapped_column(String, nullable=True) + # source: Mapped[str | None] = mapped_column(String, nullable=True) class EntityTagModel(Base): @@ -102,7 +103,7 @@ class EntityMetadataModel(Base): Enum(MetadataSource), nullable=False ) source: Mapped[str | None] = mapped_column(String, nullable=True) - date_type: Mapped[MetadataType] = mapped_column(Enum(MetadataType), nullable=False) + data_type: Mapped[MetadataType] = mapped_column(Enum(MetadataType), nullable=False) entity = relationship("EntityModel", back_populates="metadata_entries") @@ -121,7 +122,6 @@ class LibraryPluginModel(Base): plugin_id: Mapped[int] = mapped_column( Integer, ForeignKey("plugins.id"), nullable=False ) - # Create the database engine with the path from config diff --git a/memos/schemas.py b/memos/schemas.py index 4ca143b..763eb9a 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -34,11 +34,20 @@ class NewEntityParam(BaseModel): folder_id: int +class EntityMetadataParam(BaseModel): + key: str + value: str + source: str + data_type: MetadataType + + class UpdateEntityParam(BaseModel): - size: int - file_created_at: datetime - file_last_modified_at: datetime - file_type: str + size: int | None = None + file_created_at: datetime | None = None + file_last_modified_at: datetime | None = None + file_type: str | None = None + tags: List[str] = [] + attrs: List[EntityMetadataParam] = [] class UpdateTagParam(BaseModel): @@ -50,13 +59,6 @@ class UpdateEntityTagsParam(BaseModel): tags: List[str] = [] -class EntityMetadataParam(BaseModel): - key: str - value: str - source: MetadataSource - data_type: MetadataType - - class UpdateEntityMetadataParam(BaseModel): metadata_entries: List[EntityMetadataParam] @@ -96,6 +98,28 @@ class Library(BaseModel): model_config = ConfigDict(from_attributes=True) +class Tag(BaseModel): + id: int + name: str + description: str | None + color: str | None + created_at: datetime + # source: str + + model_config = ConfigDict(from_attributes=True) + + +class EntityMetadata(BaseModel): + id: int + entity_id: int + key: str + value: str + source: str + data_type: MetadataType + + model_config = ConfigDict(from_attributes=True) + + class Entity(BaseModel): id: int filepath: str @@ -107,27 +131,8 @@ class Entity(BaseModel): last_scan_at: datetime | None folder_id: int library_id: int + tags: List[Tag] = [] + metadata_entries: List[EntityMetadata] = [] model_config = ConfigDict(from_attributes=True) - -class Tag(BaseModel): - id: int - name: str - description: str | None - color: str | None - created_at: datetime - source: str - - model_config = ConfigDict(from_attributes=True) - - -class EntityMetadata(BaseModel): - id: int - entity_id: int - key: str - value: str - source: str - date_type: MetadataType - - model_config = ConfigDict(from_attributes=True) diff --git a/memos/server.py b/memos/server.py index 87dd5a6..122bfd8 100644 --- a/memos/server.py +++ b/memos/server.py @@ -197,6 +197,7 @@ async def update_entity( updated_entity: UpdateEntityParam, request: Request, db: Session = Depends(get_db), + trigger_webhooks_flag: bool = False, ): entity = crud.get_entity_by_id(entity_id, db) if entity is None or entity.library_id != library_id: @@ -212,7 +213,8 @@ async def update_entity( ) entity = crud.update_entity(entity_id, updated_entity, db) - await trigger_webhooks(library, entity, request) + if trigger_webhooks_flag: + await trigger_webhooks(library, entity, request) return entity diff --git a/memos/test_server.py b/memos/test_server.py index 2217b6d..a542165 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -15,6 +15,8 @@ from memos.schemas import ( NewEntityParam, UpdateEntityParam, NewFolderParam, + EntityMetadataParam, + MetadataType, ) from memos.models import Base @@ -72,7 +74,9 @@ def test_new_library(client): duplicate_response = client.post("/libraries", json=library_param.model_dump()) # Check that the response indicates a failure due to duplicate name assert duplicate_response.status_code == 400 - assert duplicate_response.json() == {"detail": "Library with this name already exists"} + assert duplicate_response.json() == { + "detail": "Library with this name already exists" + } def test_list_libraries(client): @@ -265,9 +269,7 @@ def test_get_entity_by_filepath(client): def test_list_entities_in_folder(client): # Setup data: Create a new library and folder - new_library = NewLibraryParam( - name="Library for List Entities Test", folders=[] - ) + new_library = NewLibraryParam(name="Library for List Entities Test", folders=[]) library_response = client.post( "/libraries", json=new_library.model_dump(mode="json") ) @@ -325,12 +327,16 @@ def test_list_entities_in_folder(client): def test_remove_entity(client): # Create a new library new_library = NewLibraryParam(name="Test Library") - library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) library_id = library_response.json()["id"] # Create a new folder in the library new_folder = NewFolderParam(path="/tmp") - folder_response = client.post(f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")) + folder_response = client.post( + f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") + ) folder_id = folder_response.json()["id"] # Create a new entity to be deleted @@ -360,7 +366,9 @@ def test_remove_entity(client): # Test for entity not found in the specified library invalid_delete_response = client.delete(f"/libraries/{library_id}/entities/9999") assert invalid_delete_response.status_code == 404 - assert invalid_delete_response.json() == {"detail": "Entity not found in the specified library"} + assert invalid_delete_response.json() == { + "detail": "Entity not found in the specified library" + } def test_add_folder_to_library(client): @@ -374,12 +382,16 @@ def test_add_folder_to_library(client): # Create a new library new_library = NewLibraryParam(name="Test Library") - library_response = client.post("/libraries", json=new_library.model_dump(mode="json")) + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) library_id = library_response.json()["id"] # Add a new folder to the library new_folder = NewFolderParam(path="/tmp/new_folder") - folder_response = client.post(f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")) + folder_response = client.post( + f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") + ) assert folder_response.status_code == 200 assert folder_response.json()["path"] == "/tmp/new_folder" @@ -391,41 +403,174 @@ def test_add_folder_to_library(client): assert "/tmp/new_folder" in folder_paths # Test for adding a folder that already exists - duplicate_folder_response = client.post(f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")) + duplicate_folder_response = client.post( + f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") + ) assert duplicate_folder_response.status_code == 400 - assert duplicate_folder_response.json() == {"detail": "Folder already exists in the library"} + assert duplicate_folder_response.json() == { + "detail": "Folder already exists in the library" + } # Test for adding a folder to a non-existent library - invalid_folder_response = client.post(f"/libraries/9999/folders", json=new_folder.model_dump(mode="json")) + invalid_folder_response = client.post( + f"/libraries/9999/folders", json=new_folder.model_dump(mode="json") + ) assert invalid_folder_response.status_code == 404 assert invalid_folder_response.json() == {"detail": "Library not found"} def test_new_plugin(client): - new_plugin = NewPluginParam(name="Test Plugin", description="A test plugin", webhook_url="http://example.com/webhook") - + new_plugin = NewPluginParam( + name="Test Plugin", + description="A test plugin", + webhook_url="http://example.com/webhook", + ) + # Make a POST request to the /plugins endpoint response = client.post("/plugins", json=new_plugin.model_dump(mode="json")) - + # Check that the response is successful assert response.status_code == 200 - + # Check the response data plugin_data = response.json() assert plugin_data["name"] == "Test Plugin" assert plugin_data["description"] == "A test plugin" assert plugin_data["webhook_url"] == "http://example.com/webhook" - + # Test for duplicate plugin name - duplicate_response = client.post("/plugins", json=new_plugin.model_dump(mode="json")) + duplicate_response = client.post( + "/plugins", json=new_plugin.model_dump(mode="json") + ) # Check that the response indicates a failure due to duplicate name assert duplicate_response.status_code == 400 - assert duplicate_response.json() == {"detail": "Plugin with this name already exists"} + assert duplicate_response.json() == { + "detail": "Plugin with this name already exists" + } # Test for another duplicate plugin name - another_duplicate_response = client.post("/plugins", json=new_plugin.model_dump(mode="json")) + another_duplicate_response = client.post( + "/plugins", json=new_plugin.model_dump(mode="json") + ) # Check that the response indicates a failure due to duplicate name assert another_duplicate_response.status_code == 400 - assert another_duplicate_response.json() == {"detail": "Plugin with this name already exists"} + assert another_duplicate_response.json() == { + "detail": "Plugin with this name already exists" + } +def test_update_entity_with_tags(client): + # Create a new library + new_library = NewLibraryParam(name="Test Library") + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) + assert library_response.status_code == 200 + library_id = library_response.json()["id"] + + # Create a new folder in the library + new_folder = NewFolderParam(path="/tmp/new_folder") + folder_response = client.post( + f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") + ) + assert folder_response.status_code == 200 + folder_id = folder_response.json()["id"] + + # Create a new entity in the folder + new_entity = NewEntityParam( + filename="test_file.txt", + filepath="/tmp/new_folder/test_file.txt", + size=1234, + file_created_at="2023-01-01T00:00:00", + file_last_modified_at="2023-01-01T00:00:00", + file_type="text/plain", + folder_id=folder_id, + ) + entity_response = client.post( + f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json") + ) + assert entity_response.status_code == 200 + entity_id = entity_response.json()["id"] + + # Update the entity with tags + update_entity_param = UpdateEntityParam(tags=["tag1", "tag2"]) + + # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint + update_response = client.put( + f"/libraries/{library_id}/entities/{entity_id}", + json=update_entity_param.model_dump(mode="json"), + ) + + # Check that the response is successful + assert update_response.status_code == 200 + + # Check the response data + updated_entity_data = update_response.json() + assert "tags" in updated_entity_data + assert len(updated_entity_data["tags"]) == 2 + assert "tag1" in [tag["name"] for tag in updated_entity_data["tags"]] + assert "tag2" in [tag["name"] for tag in updated_entity_data["tags"]] + + +def test_add_metadata_entry_to_entity_success(client): + # Create a new library + new_library = NewLibraryParam(name="Test Library for Metadata") + library_response = client.post( + "/libraries", json=new_library.model_dump(mode="json") + ) + assert library_response.status_code == 200 + library_id = library_response.json()["id"] + + # Create a new folder in the library + new_folder = NewFolderParam(path="/tmp") + folder_response = client.post( + f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") + ) + assert folder_response.status_code == 200 + folder_id = folder_response.json()["id"] + + # Create a new entity in the folder + new_entity = NewEntityParam( + filename="metadata_test_file.txt", + filepath="/tmp/metadata_folder/metadata_test_file.txt", + size=5678, + file_created_at="2023-01-01T00:00:00", + file_last_modified_at="2023-01-01T00:00:00", + file_type="text/plain", + folder_id=folder_id, + ) + entity_response = client.post( + f"/libraries/{library_id}/entities", json=new_entity.model_dump(mode="json") + ) + assert entity_response.status_code == 200 + entity_id = entity_response.json()["id"] + + # Add metadata entry to the entity + metadata_entry = EntityMetadataParam( + key="author", + value="John Doe", + source="plugin_generated", + data_type=MetadataType.ATTRIBUTE, + ) + update_entity_param = UpdateEntityParam(attrs=[metadata_entry]) + + # Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint + update_response = client.put( + f"/libraries/{library_id}/entities/{entity_id}", + json=update_entity_param.model_dump(mode="json"), + ) + + # Check that the response is successful + assert update_response.status_code == 200 + + # Check the response data + updated_entity_data = update_response.json() + assert "metadata_entries" in updated_entity_data + assert len(updated_entity_data["metadata_entries"]) == 1 + assert updated_entity_data["metadata_entries"][0]["key"] == "author" + assert updated_entity_data["metadata_entries"][0]["value"] == "John Doe" + assert updated_entity_data["metadata_entries"][0]["source"] == "plugin_generated" + assert ( + updated_entity_data["metadata_entries"][0]["data_type"] + == MetadataType.ATTRIBUTE.value + )