diff --git a/memos/plugins/ocr/server.py b/memos/plugins/ocr/server.py index 5f7b8c7..e9a1d69 100644 --- a/memos/plugins/ocr/server.py +++ b/memos/plugins/ocr/server.py @@ -1,4 +1,3 @@ -from rapidocr_onnxruntime import RapidOCR from PIL import Image import numpy as np import logging @@ -32,14 +31,28 @@ def init_thread_pool(max_workers): def init_ocr(use_gpu): global ocr - config_path = "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml" - ocr = RapidOCR(config_path=config_path) + if use_gpu: + try: + from rapidocr_paddle import RapidOCR as RapidOCRPaddle + ocr = RapidOCRPaddle(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True) + logger.info("Initialized OCR with RapidOCR Paddle (GPU)") + except ImportError: + logger.error("Failed to import rapidocr_paddle. Make sure it's installed for GPU usage.") + raise + else: + try: + from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR(config_path="ppocr.yaml") + logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)") + except ImportError: + logger.error("Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage.") + raise def convert_ocr_results(results): if results is None: return [] - + converted = [] for result in results: item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]} @@ -138,4 +151,4 @@ if __name__ == "__main__": init_thread_pool(max_workers) init_ocr(use_gpu) - uvicorn.run(app, host="0.0.0.0", port=port) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=port)