From 9642f97535760704450a87ac1b10754fb5176a75 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 11 Jun 2024 15:18:51 +0800 Subject: [PATCH] feat: add ocr plugin example --- plugins/ocr/main.py | 89 ++++++++++++++++++++++++++++++++++++ plugins/ocr/requirements.txt | 3 ++ 2 files changed, 92 insertions(+) create mode 100644 plugins/ocr/main.py create mode 100644 plugins/ocr/requirements.txt diff --git a/plugins/ocr/main.py b/plugins/ocr/main.py new file mode 100644 index 0000000..01d1926 --- /dev/null +++ b/plugins/ocr/main.py @@ -0,0 +1,89 @@ +import httpx +import json + +from fastapi import FastAPI, Request, HTTPException +from memos.schemas import Entity, MetadataType + +from rapidocr_onnxruntime import RapidOCR, VisRes + + +engine = RapidOCR() +vis = VisRes() + +METADATA_FIELD_NAME = "ocr_result" +PLUGIN_NAME = "ocr" + + +def predict(img_path): + result, elapse = engine(img_path) + if result is None: + return None, None + return [ + {"dt_boxes": item[0], "rec_txt": item[1], "score": item[2]} for item in result + ], elapse + + +app = FastAPI() + + +@app.get("/") +async def read_root(): + return {"healthy": True} + + +@app.post("/") +async def ocr(entity: Entity, request: Request): + if not entity.file_type.startswith("image/"): + return {METADATA_FIELD_NAME: "{}"} + + # Get the URL to patch the entity's metadata from the "Location" header + location_url = request.headers.get("Location") + if not location_url: + raise HTTPException(status_code=400, detail="Location header is missing") + + patch_url = f"{location_url}/metadata" + + ocr_result, _ = predict(entity.filepath) + + print(ocr_result) + if ocr_result is None or not ocr_result: + print(f"No OCR result found for file: {entity.filepath}") + return {METADATA_FIELD_NAME: "{}"} + + # Call the URL to patch the entity's metadata + async with httpx.AsyncClient() as client: + response = await client.patch( + patch_url, + json={ + "metadata_entries": [ + { + "key": METADATA_FIELD_NAME, + "value": json.dumps( + ocr_result, + default=lambda o: o.item() if hasattr(o, "item") else o, + ), + "source": PLUGIN_NAME, + "data_type": MetadataType.JSON_DATA.value, + } + ] + }, + ) + + # Check if the patch request was successful + if response.status_code != 200: + raise HTTPException( + status_code=response.status_code, detail="Failed to patch entity metadata" + ) + + return { + METADATA_FIELD_NAME: json.dumps( + ocr_result, + default=lambda o: o.item() if hasattr(o, "item") else o, + ) + } + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/plugins/ocr/requirements.txt b/plugins/ocr/requirements.txt new file mode 100644 index 0000000..24be216 --- /dev/null +++ b/plugins/ocr/requirements.txt @@ -0,0 +1,3 @@ +rapidocr_onnxruntime +httpx +fastapi