feat(ocr): update import flow

This commit is contained in:
arkohut 2024-09-05 18:15:27 +08:00
parent c1226117d6
commit e062636c08

View File

@ -1,4 +1,3 @@
from rapidocr_onnxruntime import RapidOCR
from PIL import Image
import numpy as np
import logging
@ -32,14 +31,28 @@ def init_thread_pool(max_workers):
def init_ocr(use_gpu):
global ocr
config_path = "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml"
ocr = RapidOCR(config_path=config_path)
if use_gpu:
try:
from rapidocr_paddle import RapidOCR as RapidOCRPaddle
ocr = RapidOCRPaddle(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
logger.info("Initialized OCR with RapidOCR Paddle (GPU)")
except ImportError:
logger.error("Failed to import rapidocr_paddle. Make sure it's installed for GPU usage.")
raise
else:
try:
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR(config_path="ppocr.yaml")
logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)")
except ImportError:
logger.error("Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage.")
raise
def convert_ocr_results(results):
if results is None:
return []
converted = []
for result in results:
item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
@ -138,4 +151,4 @@ if __name__ == "__main__":
init_thread_pool(max_workers)
init_ocr(use_gpu)
uvicorn.run(app, host="0.0.0.0", port=port)
uvicorn.run(app, host="0.0.0.0", port=port)