diff --git a/memos/plugins/ocr/README.md b/memos/plugins/ocr/README.md
new file mode 100644
index 0000000..65dae4a
--- /dev/null
+++ b/memos/plugins/ocr/README.md
@@ -0,0 +1,49 @@
+# OCR Plugin
+
+This plugin uses the `RapidOCR` library to perform OCR (Optical Character Recognition) on image files and stores the results in the entity's metadata.
+
+## How to Run
+
+To run this OCR plugin, follow the steps below:
+
+1. **Install the required dependencies:**
+
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+2. **Run the FastAPI application:**
+
+   You can run the FastAPI application using `uvicorn`. Make sure you are in the directory where `main.py` is located.
+
+   ```bash
+   uvicorn main:app --host 0.0.0.0 --port 8000
+   ```
+
+3. **Integration with memos:**
+
+   ```sh
+   $ python -m memos.commands plugin create ocr http://localhost:8000
+   Plugin created successfully
+   ```
+
+   ```sh
+   $ python -m memos.commands plugin ls
+
+    ID  Name  Description  Webhook URL
+    1   ocr                http://localhost:8000/
+   ```
+
+   ```sh
+   $ python -m memos.commands plugin bind --lib 1 --plugin 1
+   Plugin bound to library successfully
+   ```
+
+## Endpoints
+
+- `GET /`: Health check endpoint. Returns `{"healthy": True}` if the service is running.
+- `POST /`: OCR endpoint. Accepts an `Entity` object and a `Location` header, performs OCR on the image file, and updates the entity's metadata with the OCR results.
+
+## Metadata
+
+The OCR results are stored in the metadata field named `ocr_result` with the following structure:
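
The structure itself is not shown in this hunk. Judging from `convert_ocr_results()` and the `OCRResult` model in `server.py` below, each entry of `ocr_result` presumably looks like the following sketch (the values are illustrative only):

```python
# Assumed shape of one "ocr_result" entry, inferred from convert_ocr_results()
# and the OCRResult model in server.py; the concrete values here are made up.
ocr_result = [
    {
        "dt_boxes": [[10.0, 12.0], [215.0, 12.0], [215.0, 42.0], [10.0, 42.0]],  # text box corners
        "rec_txt": "recognized text",  # recognized string
        "score": 0.98,                 # confidence score
    },
]
```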
diff --git a/memos/plugins/ocr/fonts/simfang.ttf b/memos/plugins/ocr/fonts/simfang.ttf
new file mode 100644
index 0000000..2b59eae
Binary files /dev/null and b/memos/plugins/ocr/fonts/simfang.ttf differ
diff --git a/memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx b/memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx
new file mode 100644
index 0000000..3046e38
Binary files /dev/null and b/memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx differ
diff --git a/memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx b/memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx
new file mode 100644
index 0000000..9984d54
Binary files /dev/null and b/memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx differ
diff --git a/memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx b/memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx
new file mode 100644
index 0000000..f3c651e
Binary files /dev/null and b/memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx differ
diff --git a/memos/plugins/ocr/ppocr-gpu.yaml b/memos/plugins/ocr/ppocr-gpu.yaml
new file mode 100644
index 0000000..2a05d4b
--- /dev/null
+++ b/memos/plugins/ocr/ppocr-gpu.yaml
@@ -0,0 +1,41 @@
+Global:
+    text_score: 0.5
+    use_det: true
+    use_cls: true
+    use_rec: true
+    print_verbose: false
+    min_height: 30
+    width_height_ratio: 40
+
+Det:
+    use_cuda: true
+
+    model_path: models/ch_PP-OCRv4_det_infer.onnx
+
+    limit_side_len: 1500
+    limit_type: min
+
+    thresh: 0.3
+    box_thresh: 0.3
+    max_candidates: 1000
+    unclip_ratio: 1.6
+    use_dilation: true
+    score_mode: fast
+
+Cls:
+    use_cuda: true
+
+    model_path: models/ch_ppocr_mobile_v2.0_cls_train.onnx
+
+    cls_image_shape: [3, 48, 192]
+    cls_batch_num: 6
+    cls_thresh: 0.9
+    label_list: ['0', '180']
+
+Rec:
+    use_cuda: true
+
+    model_path: models/ch_PP-OCRv4_rec_infer.onnx
+
+    rec_img_shape: [3, 48, 320]
+    rec_batch_num: 6
\ No newline at end of file
diff --git a/memos/plugins/ocr/ppocr.yaml b/memos/plugins/ocr/ppocr.yaml
new file mode 100644
index 0000000..5a93b55
--- /dev/null
+++ b/memos/plugins/ocr/ppocr.yaml
@@ -0,0 +1,41 @@
+Global:
+    text_score: 0.5
+    use_det: true
+    use_cls: true
+    use_rec: true
+    print_verbose: false
+    min_height: 30
+    width_height_ratio: 40
+
+Det:
+    use_cuda: false
+
+    model_path: models/ch_PP-OCRv4_det_infer.onnx
+
+    limit_side_len: 1500
+    limit_type: min
+
+    thresh: 0.3
+    box_thresh: 0.3
+    max_candidates: 1000
+    unclip_ratio: 1.6
+    use_dilation: true
+    score_mode: fast
+
+Cls:
+    use_cuda: false
+
+    model_path: models/ch_ppocr_mobile_v2.0_cls_train.onnx
+
+    cls_image_shape: [3, 48, 192]
+    cls_batch_num: 6
+    cls_thresh: 0.9
+    label_list: ['0', '180']
+
+Rec:
+    use_cuda: false
+
+    model_path: models/ch_PP-OCRv4_rec_infer.onnx
+
+    rec_img_shape: [3, 48, 320]
+    rec_batch_num: 6
\ No newline at end of file
diff --git a/memos/plugins/ocr/requirements.txt b/memos/plugins/ocr/requirements.txt
new file mode 100644
index 0000000..075cb29
--- /dev/null
+++ b/memos/plugins/ocr/requirements.txt
@@ -0,0 +1,4 @@
+rapidocr_onnxruntime
+httpx
+fastapi
+# Note: If you are using GPU, you should add onnxruntime-gpu to the requirements
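
The two YAML files above are identical except for the `use_cuda` switches. A minimal sketch of how either config can be loaded with `rapidocr_onnxruntime` follows; it assumes the working directory is `memos/plugins/ocr` so the relative `model_path` entries resolve, and the GPU variant assumes `onnxruntime-gpu` is installed as noted in `requirements.txt` (`config_path` is the same keyword `server.py` uses).

```python
# Minimal sketch, not part of the plugin: load one of the configs above.
from rapidocr_onnxruntime import RapidOCR

ocr = RapidOCR(config_path="ppocr.yaml")        # CPU settings
# ocr = RapidOCR(config_path="ppocr-gpu.yaml")  # GPU settings; needs onnxruntime-gpu

# "screenshot.png" is a placeholder input; RapidOCR also accepts bytes or numpy arrays.
result, elapse = ocr("screenshot.png")
for box, text, score in result or []:
    print(text, score)
```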
diff --git a/memos/plugins/ocr/server.py b/memos/plugins/ocr/server.py
new file mode 100644
index 0000000..9bf8d31
--- /dev/null
+++ b/memos/plugins/ocr/server.py
@@ -0,0 +1,209 @@
+from PIL import Image
+import numpy as np
+import logging
+from fastapi import FastAPI, Body, HTTPException
+import base64
+import io
+import asyncio
+from pydantic import BaseModel, Field
+from typing import List
+from multiprocessing import Pool
+import threading
+import time
+import uvicorn
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+app = FastAPI()
+
+# Create the process pool
+process_pool = None
+
+
+def init_worker(use_gpu):
+    global ocr
+    ocr = init_ocr(use_gpu)
+
+
+def init_process_pool(max_workers, use_gpu):
+    global process_pool
+    process_pool = Pool(
+        processes=max_workers, initializer=init_worker, initargs=(use_gpu,)
+    )
+
+
+def init_ocr(use_gpu):
+    if use_gpu:
+        try:
+            from rapidocr_paddle import RapidOCR as RapidOCRPaddle
+
+            ocr = RapidOCRPaddle(
+                det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True
+            )
+            logger.info("Initialized OCR with RapidOCR Paddle (GPU)")
+        except ImportError:
+            logger.error(
+                "Failed to import rapidocr_paddle. Make sure it's installed for GPU usage."
+            )
+            raise
+    else:
+        try:
+            from rapidocr_onnxruntime import RapidOCR
+
+            ocr = RapidOCR(config_path="ppocr.yaml")
+            logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)")
+        except ImportError:
+            logger.error(
+                "Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage."
+            )
+            raise
+    return ocr
+
+
+def convert_ocr_results(results):
+    if results is None:
+        return []
+
+    converted = []
+    for result in results:
+        item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
+        converted.append(item)
+    return converted
+
+
+def predict(image_data):
+    global ocr
+    if ocr is None:
+        raise ValueError("OCR engine not initialized")
+
+    image = Image.open(io.BytesIO(image_data))
+    img_array = np.array(image)
+    results, _ = ocr(img_array)
+    converted_results = convert_ocr_results(results)
+    return converted_results
+
+
+def convert_to_python_type(item):
+    if isinstance(item, np.ndarray):
+        return item.tolist()
+    elif isinstance(item, np.generic):  # This includes numpy scalars like numpy.float32
+        return item.item()
+    elif isinstance(item, list):
+        return [convert_to_python_type(sub_item) for sub_item in item]
+    elif isinstance(item, dict):
+        return {key: convert_to_python_type(value) for key, value in item.items()}
+    else:
+        return item
+
+
+async def async_predict(image_data):
+    loop = asyncio.get_running_loop()
+    results = await loop.run_in_executor(
+        None, process_pool.apply, predict, (image_data,)
+    )
+    return results
+
+
+class OCRResult(BaseModel):
+    dt_boxes: List[List[float]] = Field(..., description="Bounding box coordinates")
+    rec_txt: str = Field(..., description="Recognized text")
+    score: float = Field(..., description="Confidence score")
+
+
+@app.post("/predict", response_model=List[OCRResult])
+async def predict_base64(image_base64: str = Body(..., embed=True)):
+    try:
+        if not image_base64:
+            raise HTTPException(status_code=400, detail="Missing image_base64 field")
+
+        # Remove header part if present
+        if image_base64.startswith("data:image"):
+            image_base64 = image_base64.split(",")[1]
+
+        # Decode the base64 image
+        image_data = base64.b64decode(image_base64)
+
+        # Pass the image data directly to async_predict
+        ocr_result = await async_predict(image_data)
+
+        return convert_to_python_type(ocr_result)
+
+    except Exception as e:
+        logging.error(f"Error during OCR processing: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+shutdown_event = threading.Event()
+
+
+def signal_handler(signum, frame):
+    logger.info("Received interrupt signal. Initiating shutdown...")
+    shutdown_event.set()
+
+
+def run_server(app, host, port):
+    config = uvicorn.Config(app, host=host, port=port, loop="asyncio")
+    server = uvicorn.Server(config)
+    server.install_signal_handlers = (
+        lambda: None
+    )  # Disable Uvicorn's own signal handlers
+
+    async def serve():
+        await server.serve()
+
+    thread = threading.Thread(target=asyncio.run, args=(serve(),))
+    thread.start()
+
+    try:
+        while not shutdown_event.is_set():
+            time.sleep(1)
+    except KeyboardInterrupt:
+        logger.info("Keyboard interrupt received. Initiating shutdown...")
+    finally:
+        shutdown_event.set()
+        logger.info("Stopping the server...")
+        asyncio.run(server.shutdown())
+        thread.join()
+        logger.info("Server stopped.")
+
+
+if __name__ == "__main__":
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser(description="OCR Service")
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="Port to run the OCR service on",
+    )
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=1,
+        help="Maximum number of worker processes for OCR processing",
+    )
+    parser.add_argument(
+        "--gpu",
+        action="store_true",
+        help="Use GPU for OCR processing",
+    )
+
+    args = parser.parse_args()
+    port = args.port
+    max_workers = args.max_workers
+    use_gpu = args.gpu
+
+    try:
+        init_process_pool(max_workers, use_gpu)
+        run_server(app, "0.0.0.0", port)
+    finally:
+        logger.info("Shutting down process pool...")
+        if process_pool:
+            process_pool.close()
+            process_pool.join()
+        logger.info("Process pool shut down.")
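
Since `server.py` defines its own entry point, the standalone OCR service can also be launched directly. The flags below come from the `argparse` section above; the host is hard-coded to `0.0.0.0`, and `--gpu` additionally requires `rapidocr_paddle` per `init_ocr()`:

```sh
# Defaults: CPU engine, one worker process, port 8000
python server.py

# Larger worker pool on a different port
python server.py --port 8001 --max-workers 2

# GPU engine (requires rapidocr_paddle)
python server.py --gpu
```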
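
For completeness, a minimal client call against the `/predict` endpoint defined in `server.py`: the port and the file name `screenshot.png` are placeholder assumptions, while the request and response shapes follow `predict_base64()` and the `OCRResult` model.

```python
# Sketch of a client request to the standalone OCR service (assumed to be
# running on localhost:8000); uses httpx, which is already in requirements.txt.
import base64
import httpx

with open("screenshot.png", "rb") as f:  # placeholder input image
    image_base64 = base64.b64encode(f.read()).decode()

resp = httpx.post(
    "http://localhost:8000/predict",
    json={"image_base64": image_base64},  # Body(..., embed=True) expects this key
    timeout=60,
)
resp.raise_for_status()
for item in resp.json():  # list of OCRResult entries
    print(item["rec_txt"], item["score"])
```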