feat(ocr): use multiple process for ppocr

This commit is contained in:
arkohut 2024-09-19 22:12:35 +08:00
parent 7b346b9e25
commit a47ee00e59

View File

@ -5,10 +5,12 @@ from fastapi import FastAPI, Body, HTTPException
import base64
import io
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pydantic import BaseModel, Field
from typing import List
from multiprocessing import Pool
import threading
import time
import uvicorn
# Configure logger
logging.basicConfig(level=logging.INFO)
@ -17,36 +19,48 @@ logger = logging.getLogger(__name__)
app = FastAPI()
# Initialize OCR engine (will be updated later)
ocr = None
# 创建一个线程池
thread_pool = None
# 创建进程池
process_pool = None
def init_thread_pool(max_workers):
global thread_pool
thread_pool = ThreadPoolExecutor(max_workers=max_workers)
def init_worker(use_gpu):
global ocr
ocr = init_ocr(use_gpu)
def init_process_pool(max_workers, use_gpu):
global process_pool
process_pool = Pool(
processes=max_workers, initializer=init_worker, initargs=(use_gpu,)
)
def init_ocr(use_gpu):
global ocr
if use_gpu:
try:
from rapidocr_paddle import RapidOCR as RapidOCRPaddle
ocr = RapidOCRPaddle(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
ocr = RapidOCRPaddle(
det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True
)
logger.info("Initialized OCR with RapidOCR Paddle (GPU)")
except ImportError:
logger.error("Failed to import rapidocr_paddle. Make sure it's installed for GPU usage.")
logger.error(
"Failed to import rapidocr_paddle. Make sure it's installed for GPU usage."
)
raise
else:
try:
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR(config_path="ppocr.yaml")
logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)")
except ImportError:
logger.error("Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage.")
logger.error(
"Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage."
)
raise
return ocr
def convert_ocr_results(results):
@ -60,13 +74,15 @@ def convert_ocr_results(results):
return converted
def predict(image: Image.Image):
# Convert PIL Image to numpy array if necessary
def predict(image_data):
global ocr
if ocr is None:
raise ValueError("OCR engine not initialized")
image = Image.open(io.BytesIO(image_data))
img_array = np.array(image)
results, _ = ocr(img_array)
# Convert results to desired format
converted_results = convert_ocr_results(results)
return converted_results
@ -83,10 +99,11 @@ def convert_to_python_type(item):
return item
async def async_predict(image: Image.Image):
async def async_predict(image_data):
loop = asyncio.get_running_loop()
# 在线程池中运行同步的 OCR 处理
results = await loop.run_in_executor(thread_pool, partial(predict, image))
results = await loop.run_in_executor(
None, process_pool.apply, predict, (image_data,)
)
return results
@ -108,10 +125,9 @@ async def predict_base64(image_base64: str = Body(..., embed=True)):
# Decode the base64 image
image_data = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_data))
# 使用异步函数进行 OCR 处理
ocr_result = await async_predict(image)
# 直接传递图像数据给async_predict
ocr_result = await async_predict(image_data)
return convert_to_python_type(ocr_result)
@ -120,6 +136,40 @@ async def predict_base64(image_base64: str = Body(..., embed=True)):
raise HTTPException(status_code=500, detail=str(e))
shutdown_event = threading.Event()
def signal_handler(signum, frame):
logger.info("Received interrupt signal. Initiating shutdown...")
shutdown_event.set()
def run_server(app, host, port):
config = uvicorn.Config(app, host=host, port=port, loop="asyncio")
server = uvicorn.Server(config)
server.install_signal_handlers = (
lambda: None
) # Disable Uvicorn's own signal handlers
async def serve():
await server.serve()
thread = threading.Thread(target=asyncio.run, args=(serve(),))
thread.start()
try:
while not shutdown_event.is_set():
time.sleep(1)
except KeyboardInterrupt:
logger.info("Keyboard interrupt received. Initiating shutdown...")
finally:
shutdown_event.set()
logger.info("Stopping the server...")
asyncio.run(server.shutdown())
thread.join()
logger.info("Server stopped.")
if __name__ == "__main__":
import uvicorn
import argparse
@ -148,7 +198,12 @@ if __name__ == "__main__":
max_workers = args.max_workers
use_gpu = args.gpu
init_thread_pool(max_workers)
init_ocr(use_gpu)
uvicorn.run(app, host="0.0.0.0", port=port)
try:
init_process_pool(max_workers, use_gpu)
run_server(app, "0.0.0.0", port)
finally:
logger.info("Shutting down process pool...")
if process_pool:
process_pool.close()
process_pool.join()
logger.info("Process pool shut down.")