mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-10 04:57:12 +00:00
feat(tag): add a patch tag api
This commit is contained in:
parent
b1993720e1
commit
a28fa41153
@ -73,7 +73,7 @@ def add_folders(library_id: int, folders: NewFoldersParam, db: Session) -> Libra
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_folder)
|
db.refresh(db_folder)
|
||||||
db_folders.append(Folder(id=db_folder.id, path=db_folder.path))
|
db_folders.append(Folder(id=db_folder.id, path=db_folder.path))
|
||||||
|
|
||||||
db_library = db.query(LibraryModel).filter(LibraryModel.id == library_id).first()
|
db_library = db.query(LibraryModel).filter(LibraryModel.id == library_id).first()
|
||||||
return Library(**db_library.__dict__)
|
return Library(**db_library.__dict__)
|
||||||
|
|
||||||
@ -246,6 +246,32 @@ def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
|
|||||||
return Entity(**db_entity.__dict__)
|
return Entity(**db_entity.__dict__)
|
||||||
|
|
||||||
|
|
||||||
|
def add_new_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
|
||||||
|
db_entity = get_entity_by_id(entity_id, db)
|
||||||
|
if not db_entity:
|
||||||
|
raise ValueError(f"Entity with id {entity_id} not found")
|
||||||
|
|
||||||
|
existing_tags = set(tag.name for tag in db_entity.tags)
|
||||||
|
new_tags = set(tags) - existing_tags
|
||||||
|
|
||||||
|
for tag_name in new_tags:
|
||||||
|
tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
|
||||||
|
if not tag:
|
||||||
|
tag = TagModel(name=tag_name)
|
||||||
|
db.add(tag)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(tag)
|
||||||
|
entity_tag = EntityTagModel(
|
||||||
|
entity_id=db_entity.id,
|
||||||
|
tag_id=tag.id,
|
||||||
|
source=MetadataSource.PLUGIN_GENERATED,
|
||||||
|
)
|
||||||
|
db.add(entity_tag)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_entity)
|
||||||
|
return Entity(**db_entity.__dict__)
|
||||||
|
|
||||||
|
|
||||||
def update_entity_metadata_entries(
|
def update_entity_metadata_entries(
|
||||||
entity_id: int, updated_metadata: List[EntityMetadataParam], db: Session
|
entity_id: int, updated_metadata: List[EntityMetadataParam], db: Session
|
||||||
) -> Entity:
|
) -> Entity:
|
||||||
|
@ -152,13 +152,17 @@ def new_folders(
|
|||||||
return crud.add_folders(library_id=library.id, folders=folders, db=db)
|
return crud.add_folders(library_id=library.id, folders=folders, db=db)
|
||||||
|
|
||||||
|
|
||||||
async def trigger_webhooks(library: Library, entity: Entity, request: Request, plugins: List[int] = None):
|
async def trigger_webhooks(
|
||||||
|
library: Library, entity: Entity, request: Request, plugins: List[int] = None
|
||||||
|
):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
tasks = []
|
tasks = []
|
||||||
for plugin in library.plugins:
|
for plugin in library.plugins:
|
||||||
if plugins is None or plugin.id in plugins:
|
if plugins is None or plugin.id in plugins:
|
||||||
if plugin.webhook_url:
|
if plugin.webhook_url:
|
||||||
location = str(request.url_for("get_entity_by_id", entity_id=entity.id))
|
location = str(
|
||||||
|
request.url_for("get_entity_by_id", entity_id=entity.id)
|
||||||
|
)
|
||||||
task = client.post(
|
task = client.post(
|
||||||
plugin.webhook_url,
|
plugin.webhook_url,
|
||||||
json=entity.model_dump(mode="json"),
|
json=entity.model_dump(mode="json"),
|
||||||
@ -172,7 +176,9 @@ async def trigger_webhooks(library: Library, entity: Entity, request: Request, p
|
|||||||
for plugin, response in zip(library.plugins, responses):
|
for plugin, response in zip(library.plugins, responses):
|
||||||
if plugins is None or plugin.id in plugins:
|
if plugins is None or plugin.id in plugins:
|
||||||
if isinstance(response, Exception):
|
if isinstance(response, Exception):
|
||||||
print(f"Error triggering webhook for plugin {plugin.id}: {response}")
|
print(
|
||||||
|
f"Error triggering webhook for plugin {plugin.id}: {response}"
|
||||||
|
)
|
||||||
elif response.status_code >= 400:
|
elif response.status_code >= 400:
|
||||||
print(
|
print(
|
||||||
f"Error triggering webhook for plugin {plugin.id}: {response.status_code} - {response.text}"
|
f"Error triggering webhook for plugin {plugin.id}: {response.status_code} - {response.text}"
|
||||||
@ -285,7 +291,7 @@ async def update_entity(
|
|||||||
|
|
||||||
if updated_entity:
|
if updated_entity:
|
||||||
entity = crud.update_entity(entity_id, updated_entity, db)
|
entity = crud.update_entity(entity_id, updated_entity, db)
|
||||||
|
|
||||||
if trigger_webhooks_flag:
|
if trigger_webhooks_flag:
|
||||||
library = crud.get_library_by_id(entity.library_id, db)
|
library = crud.get_library_by_id(entity.library_id, db)
|
||||||
if library is None:
|
if library is None:
|
||||||
@ -404,8 +410,21 @@ async def search_entities(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.patch("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
|
|
||||||
@app.put("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
|
@app.put("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
|
||||||
|
def replace_entity_tags(
|
||||||
|
entity_id: int, update_tags: UpdateEntityTagsParam, db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
entity = crud.get_entity_by_id(entity_id, db)
|
||||||
|
if entity is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Entity not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return crud.update_entity_tags(entity_id, update_tags.tags, db)
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
|
||||||
def patch_entity_tags(
|
def patch_entity_tags(
|
||||||
entity_id: int, update_tags: UpdateEntityTagsParam, db: Session = Depends(get_db)
|
entity_id: int, update_tags: UpdateEntityTagsParam, db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
@ -416,9 +435,7 @@ def patch_entity_tags(
|
|||||||
detail="Entity not found",
|
detail="Entity not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the CRUD function to update the tags
|
return crud.add_new_tags(entity_id, update_tags.tags, db)
|
||||||
entity = crud.update_entity_tags(entity_id, update_tags.tags, db)
|
|
||||||
return entity
|
|
||||||
|
|
||||||
|
|
||||||
@app.patch("/entities/{entity_id}/metadata", response_model=Entity, tags=["entity"])
|
@app.patch("/entities/{entity_id}/metadata", response_model=Entity, tags=["entity"])
|
||||||
|
@ -18,6 +18,7 @@ from memos.schemas import (
|
|||||||
NewFoldersParam,
|
NewFoldersParam,
|
||||||
EntityMetadataParam,
|
EntityMetadataParam,
|
||||||
MetadataType,
|
MetadataType,
|
||||||
|
UpdateEntityTagsParam,
|
||||||
UpdateEntityMetadataParam,
|
UpdateEntityMetadataParam,
|
||||||
)
|
)
|
||||||
from memos.models import Base
|
from memos.models import Base
|
||||||
@ -219,9 +220,7 @@ def test_update_entity(client):
|
|||||||
json=updated_entity.model_dump(mode="json"),
|
json=updated_entity.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
assert invalid_update_response.status_code == 404
|
assert invalid_update_response.status_code == 404
|
||||||
assert invalid_update_response.json() == {
|
assert invalid_update_response.json() == {"detail": "Entity not found"}
|
||||||
"detail": "Entity not found"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Test for getting an entity by filepath
|
# Test for getting an entity by filepath
|
||||||
@ -380,7 +379,10 @@ def test_add_folder_to_library(client):
|
|||||||
f"/libraries/{library_id}/folders", json=new_folders.model_dump(mode="json")
|
f"/libraries/{library_id}/folders", json=new_folders.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
assert folder_response.status_code == 200
|
assert folder_response.status_code == 200
|
||||||
assert any(folder["path"] == tmp_folder_path for folder in folder_response.json()["folders"])
|
assert any(
|
||||||
|
folder["path"] == tmp_folder_path
|
||||||
|
for folder in folder_response.json()["folders"]
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the folder is added
|
# Verify the folder is added
|
||||||
library_response = client.get(f"/libraries/{library_id}")
|
library_response = client.get(f"/libraries/{library_id}")
|
||||||
@ -469,6 +471,59 @@ def test_update_entity_with_tags(client):
|
|||||||
assert "tag2" in [tag["name"] for tag in updated_entity_data["tags"]]
|
assert "tag2" in [tag["name"] for tag in updated_entity_data["tags"]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_tags_to_entity(client):
|
||||||
|
library_id, _, entity_id = setup_library_with_entity(client)
|
||||||
|
|
||||||
|
# Initial tags
|
||||||
|
initial_tags = ["tag1", "tag2"]
|
||||||
|
update_entity_param = UpdateEntityTagsParam(tags=initial_tags)
|
||||||
|
|
||||||
|
# Make a PUT request to add initial tags
|
||||||
|
initial_update_response = client.put(
|
||||||
|
f"/entities/{entity_id}/tags",
|
||||||
|
json=update_entity_param.model_dump(mode="json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that the initial update is successful
|
||||||
|
assert initial_update_response.status_code == 200
|
||||||
|
initial_entity_data = initial_update_response.json()
|
||||||
|
assert len(initial_entity_data["tags"]) == 2
|
||||||
|
assert set([tag["name"] for tag in initial_entity_data["tags"]]) == set(
|
||||||
|
initial_tags
|
||||||
|
)
|
||||||
|
|
||||||
|
# New tags to patch
|
||||||
|
new_tags = ["tag3", "tag4"]
|
||||||
|
patch_entity_param = UpdateEntityTagsParam(tags=new_tags)
|
||||||
|
|
||||||
|
# Make a PATCH request to add new tags
|
||||||
|
patch_response = client.patch(
|
||||||
|
f"/entities/{entity_id}/tags",
|
||||||
|
json=patch_entity_param.model_dump(mode="json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that the patch response is successful
|
||||||
|
assert patch_response.status_code == 200
|
||||||
|
|
||||||
|
# Check the response data
|
||||||
|
patched_entity_data = patch_response.json()
|
||||||
|
assert "tags" in patched_entity_data
|
||||||
|
assert len(patched_entity_data["tags"]) == 4
|
||||||
|
assert set([tag["name"] for tag in patched_entity_data["tags"]]) == set(
|
||||||
|
initial_tags + new_tags
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the tags were actually added by making a GET request
|
||||||
|
get_response = client.get(f"/libraries/{library_id}/entities/{entity_id}")
|
||||||
|
assert get_response.status_code == 200
|
||||||
|
get_entity_data = get_response.json()
|
||||||
|
assert "tags" in get_entity_data
|
||||||
|
assert len(get_entity_data["tags"]) == 4
|
||||||
|
assert set([tag["name"] for tag in get_entity_data["tags"]]) == set(
|
||||||
|
initial_tags + new_tags
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_add_metadata_entry_to_entity_success(client):
|
def test_add_metadata_entry_to_entity_success(client):
|
||||||
library_id, _, entity_id = setup_library_with_entity(client)
|
library_id, _, entity_id = setup_library_with_entity(client)
|
||||||
|
|
||||||
@ -543,7 +598,9 @@ def test_patch_entity_metadata_entries(client):
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
update_entity_param = UpdateEntityParam(
|
update_entity_param = UpdateEntityParam(
|
||||||
metadata_entries=[EntityMetadataParam(**entry) for entry in patch_metadata_entries]
|
metadata_entries=[
|
||||||
|
EntityMetadataParam(**entry) for entry in patch_metadata_entries
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint
|
# Make a PUT request to the /libraries/{library_id}/entities/{entity_id} endpoint
|
||||||
|
Loading…
x
Reference in New Issue
Block a user