diff --git a/memos/config.py b/memos/config.py index b160337..3a80a78 100644 --- a/memos/config.py +++ b/memos/config.py @@ -7,13 +7,18 @@ class VLMSettings(BaseModel): enabled: bool = False modelname: str = "internvl-1.5" endpoint: str = "http://localhost:11434" + +class OCRSettings(BaseModel): + enabled: bool = True + endpoint: str = "http://localhost:5555/predict" token: str = "" - concurrency: int = 8 + concurrency: int = 4 + class Settings(BaseSettings): model_config = SettingsConfigDict( yaml_file=str(Path.home() / ".memos" / "config.yaml"), - yaml_file_encoding="utf-8" + yaml_file_encoding="utf-8", ) base_dir: str = str(Path.home() / ".memos") @@ -28,6 +33,10 @@ class Settings(BaseSettings): # VLM plugin settings vlm: VLMSettings = VLMSettings() + # OCR plugin settings + ocr: OCRSettings = OCRSettings() + + settings = Settings() # Define the default database path @@ -38,4 +47,4 @@ TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name # Function to get the database path from environment variable or default def get_database_path(): - return settings.database_path \ No newline at end of file + return settings.database_path diff --git a/memos/models.py b/memos/models.py index 9f1727b..94b3940 100644 --- a/memos/models.py +++ b/memos/models.py @@ -148,6 +148,25 @@ class LibraryPluginModel(Base): ) +def initialize_default_plugins(session): + default_plugins = [ + PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"), + PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"), + ] + + for plugin in default_plugins: + existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first() + if not existing_plugin: + session.add(plugin) + + session.commit() + # Create the database engine with the path from config engine = create_engine(f"sqlite:///{get_database_path()}") Base.metadata.create_all(engine) + +# Initialize default plugins +from sqlalchemy.orm import sessionmaker +Session = sessionmaker(bind=engine) +with Session() as session: + initialize_default_plugins(session) \ No newline at end of file diff --git a/memos/plugins/ocr/README.md b/memos/plugins/ocr/README.md new file mode 100644 index 0000000..65dae4a --- /dev/null +++ b/memos/plugins/ocr/README.md @@ -0,0 +1,49 @@ +# OCR Plugin + +This is a README file for the OCR plugin. This plugin uses the `RapidOCR` library to perform OCR (Optical Character Recognition) on image files and updates the metadata of the entity with the OCR results. + +## How to Run + +To run this OCR plugin, follow the steps below: + +1. **Install the required dependencies:** + + ```bash + pip install -r requirements.txt + ``` + +2. **Run the FastAPI application:** + + You can run the FastAPI application using `uvicorn`. Make sure you are in the directory where `main.py` is located. + + ```bash + uvicorn main:app --host 0.0.0.0 --port 8000 + ``` + +3. **Integration with memos:** + + ```sh + $ python -m memos.commands plugin create ocr http://localhost:8000 + Plugin created successfully + ``` + + ```sh + $ python -m memos.commands plugin ls + + ID Name Description Webhook URL + 1 ocr http://localhost:8000/ + ``` + + ```sh + $ python -m memos.commands plugin bind --lib 1 --plugin 1 + Plugin bound to library successfully + ``` + +## Endpoints + +- `GET /`: Health check endpoint. Returns `{"healthy": True}` if the service is running. +- `POST /`: OCR endpoint. Accepts an `Entity` object and a `Location` header. Performs OCR on the image file and updates the entity's metadata with the OCR results. + +## Metadata + +The OCR results are stored in the metadata field named `ocr_result` with the following structure: diff --git a/memos/plugins/ocr/__init__.py b/memos/plugins/ocr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/ocr/fonts/simfang.ttf b/memos/plugins/ocr/fonts/simfang.ttf similarity index 100% rename from plugins/ocr/fonts/simfang.ttf rename to memos/plugins/ocr/fonts/simfang.ttf diff --git a/plugins/ocr/main.py b/memos/plugins/ocr/main.py similarity index 76% rename from plugins/ocr/main.py rename to memos/plugins/ocr/main.py index d01f4a4..073e8bb 100644 --- a/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -8,17 +8,20 @@ import io import os from PIL import Image -from fastapi import FastAPI, Request, HTTPException +from fastapi import APIRouter, FastAPI, Request, HTTPException from memos.schemas import Entity, MetadataType METADATA_FIELD_NAME = "ocr_result" PLUGIN_NAME = "ocr" -app = FastAPI() - +router = APIRouter( + tags=[PLUGIN_NAME], + responses={404: {"description": "Not found"}} +) endpoint = None token = None -semaphore = asyncio.Semaphore(4) +concurrency = None +semaphore = None # Configure logger logging.basicConfig(level=logging.INFO) @@ -50,6 +53,7 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N return response.json() +# Modify the predict function to use semaphore async def predict(img_path): image_base64 = image2base64(img_path) if not image_base64: @@ -59,19 +63,18 @@ async def predict(img_path): headers = {} if token: headers["Authorization"] = f"Bearer {token}" - ocr_result = await fetch(endpoint, client, image_base64, headers=headers) + async with semaphore: + ocr_result = await fetch(endpoint, client, image_base64, headers=headers) return ocr_result -app = FastAPI() - - -@app.get("/") +@router.get("/") async def read_root(): return {"healthy": True} -@app.post("/") +@router.post("", include_in_schema=False) +@router.post("/") async def ocr(entity: Entity, request: Request): if not entity.file_type_group == "image": return {METADATA_FIELD_NAME: "{}"} @@ -123,27 +126,45 @@ async def ocr(entity: Entity, request: Request): } +def init_plugin(config): + global endpoint, token, concurrency, semaphore + endpoint = config.endpoint + token = config.token + concurrency = config.concurrency + semaphore = asyncio.Semaphore(concurrency) + + print(f"Endpoint: {endpoint}") + print(f"Token: {token}") + print(f"Concurrency: {concurrency}") + + if __name__ == "__main__": import uvicorn import argparse + from fastapi import FastAPI parser = argparse.ArgumentParser(description="OCR Plugin") parser.add_argument( "--endpoint", type=str, - required=True, + default="http://localhost:8080", help="The endpoint URL for the OCR service", ) parser.add_argument( - "--token", type=str, required=False, help="The token for authentication" + "--token", type=str, default="", help="The token for authentication" + ) + parser.add_argument( + "--concurrency", type=int, default=4, help="The concurrency level" ) parser.add_argument( "--port", type=int, default=8000, help="The port number to run the server on" ) args = parser.parse_args() - endpoint = args.endpoint - token = args.token - port = args.port - uvicorn.run(app, host="0.0.0.0", port=port) + init_plugin(args) + + app = FastAPI() + app.include_router(router) + + uvicorn.run(app, host="0.0.0.0", port=args.port) \ No newline at end of file diff --git a/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx b/memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx similarity index 100% rename from plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx rename to memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx diff --git a/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx b/memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx similarity index 100% rename from plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx rename to memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx diff --git a/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx b/memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx similarity index 100% rename from plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx rename to memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx diff --git a/plugins/ocr/ppocr-gpu.yaml b/memos/plugins/ocr/ppocr-gpu.yaml similarity index 100% rename from plugins/ocr/ppocr-gpu.yaml rename to memos/plugins/ocr/ppocr-gpu.yaml diff --git a/plugins/ocr/ppocr.yaml b/memos/plugins/ocr/ppocr.yaml similarity index 100% rename from plugins/ocr/ppocr.yaml rename to memos/plugins/ocr/ppocr.yaml diff --git a/plugins/ocr/requirements.txt b/memos/plugins/ocr/requirements.txt similarity index 100% rename from plugins/ocr/requirements.txt rename to memos/plugins/ocr/requirements.txt diff --git a/plugins/ocr/server.py b/memos/plugins/ocr/server.py similarity index 100% rename from plugins/ocr/server.py rename to memos/plugins/ocr/server.py diff --git a/memos/server.py b/memos/server.py index bf9e6da..49b67a7 100644 --- a/memos/server.py +++ b/memos/server.py @@ -88,6 +88,18 @@ app.mount( "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True) ) +# Add VLM plugin router +if settings.vlm.enabled: + print("VLM plugin is enabled") + vlm_main.init_plugin(settings.vlm) + app.include_router(vlm_main.router, prefix="/plugins/vlm") + +# Add OCR plugin router +if settings.ocr.enabled: + print("OCR plugin is enabled") + ocr_main.init_plugin(settings.ocr) + app.include_router(ocr_main.router, prefix="/plugins/ocr") + @app.get("/favicon.png", response_class=FileResponse) async def favicon_png(): @@ -178,8 +190,12 @@ async def trigger_webhooks( location = str( request.url_for("get_entity_by_id", entity_id=entity.id) ) + webhook_url = plugin.webhook_url + if webhook_url.startswith("/"): + webhook_url = str(request.base_url)[:-1] + webhook_url + print(f"webhook_url: {webhook_url}") task = client.post( - plugin.webhook_url, + webhook_url, json=entity.model_dump(mode="json"), headers={"Location": location}, timeout=60.0, @@ -683,19 +699,6 @@ async def get_file(file_path: str): raise HTTPException(status_code=404, detail="File not found") -# Add VLM plugin router -if settings.vlm.enabled: - print("VLM plugin is enabled") - vlm_main.init_plugin(settings.vlm) - app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}") - -# Add OCR plugin router -if settings.ocr.enabled: - print("OCR plugin is enabled") - ocr_main.init_plugin(settings.ocr) - app.include_router(ocr_main.router, prefix=f"/plugins/{ocr_main.PLUGIN_NAME}") - - def run_server(): print("Database path:", get_database_path()) print(