diff --git a/memos/crud.py b/memos/crud.py index 4b55203..6fae7f9 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -73,7 +73,7 @@ def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entit db.add(db_entity) db.commit() db.refresh(db_entity) - return db_entity + return Entity(**db_entity.__dict__) def get_entity_by_id(entity_id: int, db: Session) -> Entity | None: @@ -153,4 +153,4 @@ def update_entity( setattr(db_entity, key, value) db.commit() db.refresh(db_entity) - return db_entity + return Entity(**db_entity.__dict__) diff --git a/memos/server.py b/memos/server.py index 11c4fdc..87dd5a6 100644 --- a/memos/server.py +++ b/memos/server.py @@ -1,9 +1,11 @@ +import httpx import uvicorn -from fastapi import FastAPI, HTTPException, Depends, status, Query +from fastapi import FastAPI, HTTPException, Depends, status, Query, Request from sqlalchemy.orm import Session from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from typing import List, Annotated +import asyncio from .config import get_database_path import memos.crud as crud @@ -94,9 +96,41 @@ def new_folder( return crud.add_folder(library_id=library.id, folder=folder, db=db) +async def trigger_webhooks(library, entity, request): + async with httpx.AsyncClient() as client: + tasks = [] + for plugin in library.plugins: + if plugin.webhook_url: + location = str( + request.url_for( + "get_entity_by_id", library_id=library.id, entity_id=entity.id + ) + ) + task = client.post( + plugin.webhook_url, + json={"entity": entity.model_dump(mode="json")}, + headers={"Location": location}, + timeout=10.0, # Adding a timeout of 10 seconds + ) + tasks.append(task) + + responses = await asyncio.gather(*tasks, return_exceptions=True) + + for plugin, response in zip(library.plugins, responses): + if isinstance(response, Exception): + print(f"Error triggering webhook for plugin {plugin.id}: {response}") + elif response.status_code >= 400: + print( + f"Error triggering webhook for plugin {plugin.id}: {response.status_code} - {response.text}" + ) + + @app.post("/libraries/{library_id}/entities", response_model=Entity) -def new_entity( - new_entity: NewEntityParam, library_id: int, db: Session = Depends(get_db) +async def new_entity( + new_entity: NewEntityParam, + library_id: int, + request: Request, + db: Session = Depends(get_db), ): library = crud.get_library_by_id(library_id, db) if library is None: @@ -105,6 +139,7 @@ def new_entity( ) entity = crud.create_entity(library_id, new_entity, db) + await trigger_webhooks(library, entity, request) return entity @@ -156,10 +191,11 @@ def get_entity_by_id(library_id: int, entity_id: int, db: Session = Depends(get_ @app.put("/libraries/{library_id}/entities/{entity_id}", response_model=Entity) -def update_entity( +async def update_entity( library_id: int, entity_id: int, updated_entity: UpdateEntityParam, + request: Request, db: Session = Depends(get_db), ): entity = crud.get_entity_by_id(entity_id, db) @@ -169,7 +205,14 @@ def update_entity( detail="Entity not found in the specified library", ) + library = crud.get_library_by_id(library_id, db) + if library is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Library not found" + ) + entity = crud.update_entity(entity_id, updated_entity, db) + await trigger_webhooks(library, entity, request) return entity @@ -206,7 +249,6 @@ def list_plugins(db: Session = Depends(get_db)): return plugins - @app.post("/libraries/{library_id}/plugins", status_code=status.HTTP_204_NO_CONTENT) def add_library_plugin( library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db)