From fca387b22d7f574ce81d5542e6460766b67973c2 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 13:55:52 +0800 Subject: [PATCH] feat: support bind plugin by name --- memos/commands.py | 16 ++++++++++------ memos/schemas.py | 31 ++++++++++++++++++++++++++++--- memos/server.py | 21 +++++++++++++++++++-- 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index 47242be..744bd9d 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -723,18 +723,22 @@ def create(name: str, webhook_url: str, description: str = ""): @plugin_app.command("bind") def bind( library_id: int = typer.Option(..., "--lib", help="ID of the library"), - plugin_id: int = typer.Option(..., "--plugin", help="ID of the plugin"), + plugin: str = typer.Option(..., "--plugin", help="ID or name of the plugin"), ): + try: + plugin_id = int(plugin) + plugin_param = {"plugin_id": plugin_id} + except ValueError: + plugin_param = {"plugin_name": plugin} + response = httpx.post( f"{BASE_URL}/libraries/{library_id}/plugins", - json={"plugin_id": plugin_id}, + json=plugin_param, ) - if 200 <= response.status_code < 300: + if response.status_code == 204: print("Plugin bound to library successfully") else: - print( - f"Failed to bind plugin to library: {response.status_code} - {response.text}" - ) + print(f"Failed to bind plugin to library: {response.status_code} - {response.text}") @plugin_app.command("unbind") diff --git a/memos/schemas.py b/memos/schemas.py index edd040f..b1b6042 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -1,4 +1,11 @@ -from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field +from pydantic import ( + BaseModel, + ConfigDict, + DirectoryPath, + HttpUrl, + Field, + model_validator, +) from typing import List, Optional, Any, Dict from datetime import datetime from enum import Enum @@ -79,7 +86,18 @@ class NewPluginParam(BaseModel): class NewLibraryPluginParam(BaseModel): - plugin_id: int + plugin_id: Optional[int] = None + plugin_name: Optional[str] = None + + @model_validator(mode="after") + def check_either_id_or_name(self): + plugin_id = self.plugin_id + plugin_name = self.plugin_name + if not (plugin_id or plugin_name): + raise ValueError("Either plugin_id or plugin_name must be provided") + if plugin_id is not None and plugin_name is not None: + raise ValueError("Only one of plugin_id or plugin_name should be provided") + return self class Folder(BaseModel): @@ -214,15 +232,18 @@ class FacetCount(BaseModel): highlighted: str value: str + class FacetStats(BaseModel): total_values: int + class Facet(BaseModel): counts: List[FacetCount] field_name: str sampled: bool stats: FacetStats + class TextMatchInfo(BaseModel): best_field_score: str best_field_weight: int @@ -232,9 +253,11 @@ class TextMatchInfo(BaseModel): tokens_matched: int typo_prefix_score: int + class HybridSearchInfo(BaseModel): rank_fusion_score: float + class SearchHit(BaseModel): document: EntitySearchResult highlight: Dict[str, Any] = {} @@ -243,12 +266,14 @@ class SearchHit(BaseModel): text_match: Optional[int] = None text_match_info: Optional[TextMatchInfo] = None + class RequestParams(BaseModel): collection_name: str first_q: str per_page: int q: str + class SearchResult(BaseModel): facet_counts: List[Facet] found: int @@ -257,4 +282,4 @@ class SearchResult(BaseModel): page: int request_params: RequestParams search_cutoff: bool - search_time_ms: int \ No newline at end of file + search_time_ms: int diff --git a/memos/server.py b/memos/server.py index d48e355..f40982e 100644 --- a/memos/server.py +++ b/memos/server.py @@ -602,12 +602,29 @@ def add_library_plugin( library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db) ): library = crud.get_library_by_id(library_id, db) - if any(plugin.id == new_plugin.plugin_id for plugin in library.plugins): + if library is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Library not found" + ) + + plugin = None + if new_plugin.plugin_id is not None: + plugin = crud.get_plugin_by_id(new_plugin.plugin_id, db) + elif new_plugin.plugin_name is not None: + plugin = crud.get_plugin_by_name(new_plugin.plugin_name, db) + + if plugin is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found" + ) + + if any(p.id == plugin.id for p in library.plugins): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Plugin already exists in the library", ) - crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db) + + crud.add_plugin_to_library(library_id, plugin.id, db) @app.delete(