Mirror of https://github.com/tcsenpai/pensieve.git, synced 2025-06-06 19:25:24 +00:00
feat(ocr): use multiple processes for ppocr
This commit is contained in:
parent 7b346b9e25
commit a47ee00e59
@@ -5,10 +5,12 @@ from fastapi import FastAPI, Body, HTTPException
 import base64
 import io
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 from pydantic import BaseModel, Field
 from typing import List
+from multiprocessing import Pool
+import threading
+import time
+import uvicorn
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -17,36 +19,48 @@ logger = logging.getLogger(__name__)
 
 app = FastAPI()
 
 # Initialize OCR engine (will be updated later)
 ocr = None
 
-# Create a thread pool
-thread_pool = None
+# Create a process pool
+process_pool = None
 
 
-def init_thread_pool(max_workers):
-    global thread_pool
-    thread_pool = ThreadPoolExecutor(max_workers=max_workers)
+def init_worker(use_gpu):
+    global ocr
+    ocr = init_ocr(use_gpu)
+
+
+def init_process_pool(max_workers, use_gpu):
+    global process_pool
+    process_pool = Pool(
+        processes=max_workers, initializer=init_worker, initargs=(use_gpu,)
+    )
 
 
 def init_ocr(use_gpu):
     global ocr
     if use_gpu:
         try:
             from rapidocr_paddle import RapidOCR as RapidOCRPaddle
-            ocr = RapidOCRPaddle(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+
+            ocr = RapidOCRPaddle(
+                det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True
+            )
             logger.info("Initialized OCR with RapidOCR Paddle (GPU)")
         except ImportError:
-            logger.error("Failed to import rapidocr_paddle. Make sure it's installed for GPU usage.")
+            logger.error(
+                "Failed to import rapidocr_paddle. Make sure it's installed for GPU usage."
+            )
             raise
     else:
         try:
             from rapidocr_onnxruntime import RapidOCR
+
             ocr = RapidOCR(config_path="ppocr.yaml")
             logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)")
         except ImportError:
-            logger.error("Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage.")
+            logger.error(
+                "Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage."
+            )
             raise
     return ocr
 
 
 def convert_ocr_results(results):
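For reference, a minimal sketch of the initializer pattern this hunk relies on, with a toy engine string standing in for the RapidOCR instance (the names engine, init_worker, and predict below are illustrative, not the commit's code): each worker process runs the initializer once and keeps its own engine in a module-level global, so the heavy OCR object is never pickled per request.

from multiprocessing import Pool

engine = None  # per-process state, analogous to the global `ocr`

def init_worker(use_gpu):
    # Runs once in each worker process when the pool starts.
    global engine
    engine = f"engine(gpu={use_gpu})"  # stands in for init_ocr(use_gpu)

def predict(x):
    # Runs in a worker; `engine` was set by the initializer for this process.
    return f"{engine} -> {x}"

if __name__ == "__main__":
    pool = Pool(processes=2, initializer=init_worker, initargs=(False,))
    try:
        print(pool.map(predict, ["a", "b", "c"]))
    finally:
        pool.close()
        pool.join()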
@@ -60,13 +74,15 @@ def convert_ocr_results(results):
     return converted
 
 
-def predict(image: Image.Image):
-    # Convert PIL Image to numpy array if necessary
+def predict(image_data):
+    global ocr
+    if ocr is None:
+        raise ValueError("OCR engine not initialized")
+
+    image = Image.open(io.BytesIO(image_data))
     img_array = np.array(image)
     results, _ = ocr(img_array)
     # Convert results to desired format
     converted_results = convert_ocr_results(results)
 
     return converted_results
 
 
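The reworked predict receives the raw encoded image bytes and does the PIL/numpy decoding inside the worker, so only the compact payload has to cross the process boundary. A toy sketch of that hand-off (decode_shape and the generated PNG are illustrative, not part of the commit):

import io
from multiprocessing import Pool

import numpy as np
from PIL import Image

def decode_shape(image_data):
    # Runs inside the worker: the raw bytes were pickled and sent over,
    # and the PIL/numpy decoding happens in the worker process.
    image = Image.open(io.BytesIO(image_data))
    return np.array(image).shape

if __name__ == "__main__":
    buf = io.BytesIO()
    Image.new("RGB", (32, 16)).save(buf, format="PNG")  # toy payload
    with Pool(processes=1) as pool:
        print(pool.apply(decode_shape, (buf.getvalue(),)))  # -> (16, 32, 3)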
@@ -83,10 +99,11 @@ def convert_to_python_type(item):
     return item
 
 
-async def async_predict(image: Image.Image):
+async def async_predict(image_data):
     loop = asyncio.get_running_loop()
-    # Run the synchronous OCR processing in the thread pool
-    results = await loop.run_in_executor(thread_pool, partial(predict, image))
+    results = await loop.run_in_executor(
+        None, process_pool.apply, predict, (image_data,)
+    )
     return results
 
 
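Pool.apply blocks until the worker returns, so async_predict wraps it in loop.run_in_executor(None, ...) to keep the event loop responsive while the OCR runs in another process. A self-contained sketch of the same pattern, with a stand-in cpu_task instead of the real OCR call (illustrative names, not the commit's code):

import asyncio
from multiprocessing import Pool

def cpu_task(n):
    # Stands in for the blocking OCR call executed inside a worker process.
    return sum(i * i for i in range(n))

async def main(pool):
    loop = asyncio.get_running_loop()
    # pool.apply blocks until the worker returns, so it is run on the default
    # ThreadPoolExecutor (executor=None) instead of blocking the event loop.
    result = await loop.run_in_executor(None, pool.apply, cpu_task, (100_000,))
    print(result)

if __name__ == "__main__":
    with Pool(processes=2) as pool:
        asyncio.run(main(pool))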
@@ -108,10 +125,9 @@ async def predict_base64(image_base64: str = Body(..., embed=True)):
 
         # Decode the base64 image
         image_data = base64.b64decode(image_base64)
-        image = Image.open(io.BytesIO(image_data))
 
-        # Use the async function for OCR processing
-        ocr_result = await async_predict(image)
+        # Pass the image data directly to async_predict
+        ocr_result = await async_predict(image_data)
 
         return convert_to_python_type(ocr_result)
 
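A hedged usage sketch for the predict_base64 endpoint: Body(..., embed=True) means the JSON body carries an image_base64 key, but the route path and port below are assumptions, since the decorator itself is outside this diff.

import base64

import requests  # assumed to be available in the client environment

with open("screenshot.png", "rb") as f:  # hypothetical input image
    image_base64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://localhost:8000/predict",  # assumed path and port
    json={"image_base64": image_base64},
)
print(resp.json())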
@@ -120,6 +136,40 @@ async def predict_base64(image_base64: str = Body(..., embed=True)):
         raise HTTPException(status_code=500, detail=str(e))
 
 
+shutdown_event = threading.Event()
+
+
+def signal_handler(signum, frame):
+    logger.info("Received interrupt signal. Initiating shutdown...")
+    shutdown_event.set()
+
+
+def run_server(app, host, port):
+    config = uvicorn.Config(app, host=host, port=port, loop="asyncio")
+    server = uvicorn.Server(config)
+    server.install_signal_handlers = (
+        lambda: None
+    )  # Disable Uvicorn's own signal handlers
+
+    async def serve():
+        await server.serve()
+
+    thread = threading.Thread(target=asyncio.run, args=(serve(),))
+    thread.start()
+
+    try:
+        while not shutdown_event.is_set():
+            time.sleep(1)
+    except KeyboardInterrupt:
+        logger.info("Keyboard interrupt received. Initiating shutdown...")
+    finally:
+        shutdown_event.set()
+        logger.info("Stopping the server...")
+        asyncio.run(server.shutdown())
+        thread.join()
+        logger.info("Server stopped.")
+
+
 if __name__ == "__main__":
     import uvicorn
     import argparse
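run_server keeps the asyncio loop in a background thread and lets the main thread poll a threading.Event so Ctrl+C can drive an orderly shutdown. A stdlib-only sketch of that structure, with serve_forever standing in for server.serve() (illustrative, not the commit's code):

import asyncio
import threading
import time

shutdown_event = threading.Event()

async def serve_forever():
    # Stands in for server.serve(); exits once shutdown is requested.
    while not shutdown_event.is_set():
        await asyncio.sleep(0.1)

def main():
    # Run the event loop in a background thread, keep the main thread free
    # to watch for Ctrl+C and signal the shutdown event.
    thread = threading.Thread(target=asyncio.run, args=(serve_forever(),))
    thread.start()
    try:
        while not shutdown_event.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        pass
    finally:
        shutdown_event.set()
        thread.join()

if __name__ == "__main__":
    main()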
@@ -148,7 +198,12 @@ if __name__ == "__main__":
     max_workers = args.max_workers
     use_gpu = args.gpu
 
-    init_thread_pool(max_workers)
-    init_ocr(use_gpu)
-
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    try:
+        init_process_pool(max_workers, use_gpu)
+        run_server(app, "0.0.0.0", port)
+    finally:
+        logger.info("Shutting down process pool...")
+        if process_pool:
+            process_pool.close()
+            process_pool.join()
+        logger.info("Process pool shut down.")
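The finally block uses the standard close/join sequence: close() stops new tasks from being submitted and join() waits for the worker processes to exit. A toy illustration (square and the toy pool are not part of the commit):

from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == "__main__":
    pool = Pool(processes=2)
    try:
        print(pool.map(square, range(5)))
    finally:
        pool.close()  # stop accepting new work
        pool.join()   # wait for worker processes to exit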