diff --git a/memos/plugins/ocr/server.py b/memos/plugins/ocr/server.py index e9a1d69..9bf8d31 100644 --- a/memos/plugins/ocr/server.py +++ b/memos/plugins/ocr/server.py @@ -5,10 +5,12 @@ from fastapi import FastAPI, Body, HTTPException import base64 import io import asyncio -from concurrent.futures import ThreadPoolExecutor -from functools import partial from pydantic import BaseModel, Field from typing import List +from multiprocessing import Pool +import threading +import time +import uvicorn # Configure logger logging.basicConfig(level=logging.INFO) @@ -17,36 +19,48 @@ logger = logging.getLogger(__name__) app = FastAPI() -# Initialize OCR engine (will be updated later) -ocr = None - -# 创建一个线程池 -thread_pool = None +# 创建进程池 +process_pool = None -def init_thread_pool(max_workers): - global thread_pool - thread_pool = ThreadPoolExecutor(max_workers=max_workers) +def init_worker(use_gpu): + global ocr + ocr = init_ocr(use_gpu) + + +def init_process_pool(max_workers, use_gpu): + global process_pool + process_pool = Pool( + processes=max_workers, initializer=init_worker, initargs=(use_gpu,) + ) def init_ocr(use_gpu): - global ocr if use_gpu: try: from rapidocr_paddle import RapidOCR as RapidOCRPaddle - ocr = RapidOCRPaddle(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True) + + ocr = RapidOCRPaddle( + det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True + ) logger.info("Initialized OCR with RapidOCR Paddle (GPU)") except ImportError: - logger.error("Failed to import rapidocr_paddle. Make sure it's installed for GPU usage.") + logger.error( + "Failed to import rapidocr_paddle. Make sure it's installed for GPU usage." + ) raise else: try: from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR(config_path="ppocr.yaml") logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)") except ImportError: - logger.error("Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage.") + logger.error( + "Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage." + ) raise + return ocr def convert_ocr_results(results): @@ -60,13 +74,15 @@ def convert_ocr_results(results): return converted -def predict(image: Image.Image): - # Convert PIL Image to numpy array if necessary +def predict(image_data): + global ocr + if ocr is None: + raise ValueError("OCR engine not initialized") + + image = Image.open(io.BytesIO(image_data)) img_array = np.array(image) results, _ = ocr(img_array) - # Convert results to desired format converted_results = convert_ocr_results(results) - return converted_results @@ -83,10 +99,11 @@ def convert_to_python_type(item): return item -async def async_predict(image: Image.Image): +async def async_predict(image_data): loop = asyncio.get_running_loop() - # 在线程池中运行同步的 OCR 处理 - results = await loop.run_in_executor(thread_pool, partial(predict, image)) + results = await loop.run_in_executor( + None, process_pool.apply, predict, (image_data,) + ) return results @@ -108,10 +125,9 @@ async def predict_base64(image_base64: str = Body(..., embed=True)): # Decode the base64 image image_data = base64.b64decode(image_base64) - image = Image.open(io.BytesIO(image_data)) - # 使用异步函数进行 OCR 处理 - ocr_result = await async_predict(image) + # 直接传递图像数据给async_predict + ocr_result = await async_predict(image_data) return convert_to_python_type(ocr_result) @@ -120,6 +136,40 @@ async def predict_base64(image_base64: str = Body(..., embed=True)): raise HTTPException(status_code=500, detail=str(e)) +shutdown_event = threading.Event() + + +def signal_handler(signum, frame): + logger.info("Received interrupt signal. Initiating shutdown...") + shutdown_event.set() + + +def run_server(app, host, port): + config = uvicorn.Config(app, host=host, port=port, loop="asyncio") + server = uvicorn.Server(config) + server.install_signal_handlers = ( + lambda: None + ) # Disable Uvicorn's own signal handlers + + async def serve(): + await server.serve() + + thread = threading.Thread(target=asyncio.run, args=(serve(),)) + thread.start() + + try: + while not shutdown_event.is_set(): + time.sleep(1) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received. Initiating shutdown...") + finally: + shutdown_event.set() + logger.info("Stopping the server...") + asyncio.run(server.shutdown()) + thread.join() + logger.info("Server stopped.") + + if __name__ == "__main__": import uvicorn import argparse @@ -148,7 +198,12 @@ if __name__ == "__main__": max_workers = args.max_workers use_gpu = args.gpu - init_thread_pool(max_workers) - init_ocr(use_gpu) - - uvicorn.run(app, host="0.0.0.0", port=port) + try: + init_process_pool(max_workers, use_gpu) + run_server(app, "0.0.0.0", port) + finally: + logger.info("Shutting down process pool...") + if process_pool: + process_pool.close() + process_pool.join() + logger.info("Process pool shut down.")