From e19ebe605436befdfc58931cb30634c614c07f57 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Sat, 17 Aug 2024 00:21:38 +0800 Subject: [PATCH] feat(ocr): update main --- plugins/ocr/main.py | 90 +++++++++++++++++++++++++++++++++++-------- plugins/ocr/server.py | 10 +++++ 2 files changed, 85 insertions(+), 15 deletions(-) diff --git a/plugins/ocr/main.py b/plugins/ocr/main.py index a1dd52f..d01f4a4 100644 --- a/plugins/ocr/main.py +++ b/plugins/ocr/main.py @@ -1,26 +1,66 @@ +import asyncio +import logging +from typing import Optional import httpx import json +import base64 +import io +import os +from PIL import Image from fastapi import FastAPI, Request, HTTPException from memos.schemas import Entity, MetadataType -from rapidocr_onnxruntime import RapidOCR, VisRes - - -engine = RapidOCR() -vis = VisRes() - METADATA_FIELD_NAME = "ocr_result" PLUGIN_NAME = "ocr" +app = FastAPI() -def predict(img_path): - result, elapse = engine(img_path) - if result is None: - return None, None - return [ - {"dt_boxes": item[0], "rec_txt": item[1], "score": item[2]} for item in result - ], elapse +endpoint = None +token = None +semaphore = asyncio.Semaphore(4) + +# Configure logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def image2base64(img_path): + try: + with Image.open(img_path) as img: + img.convert("RGB") # Check if image is not broken + with open(img_path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + return encoded_string + except Exception as e: + logger.error(f"Error processing image {img_path}: {str(e)}") + return None + + +async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = None): + async with semaphore: # 使用信号量控制并发 + response = await client.post( + f"{endpoint}", + json={"image_base64": image_base64}, + timeout=60, + headers=headers, + ) + if response.status_code != 200: + return None + return response.json() + + +async def predict(img_path): + image_base64 = image2base64(img_path) + if not image_base64: + return None + + async with httpx.AsyncClient() as client: + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + ocr_result = await fetch(endpoint, client, image_base64, headers=headers) + return ocr_result app = FastAPI() @@ -43,7 +83,7 @@ async def ocr(entity: Entity, request: Request): patch_url = f"{location_url}/metadata" - ocr_result, _ = predict(entity.filepath) + ocr_result = await predict(entity.filepath) print(ocr_result) if ocr_result is None or not ocr_result: @@ -85,5 +125,25 @@ async def ocr(entity: Entity, request: Request): if __name__ == "__main__": import uvicorn + import argparse - uvicorn.run(app, host="0.0.0.0", port=8000) + parser = argparse.ArgumentParser(description="OCR Plugin") + parser.add_argument( + "--endpoint", + type=str, + required=True, + help="The endpoint URL for the OCR service", + ) + parser.add_argument( + "--token", type=str, required=False, help="The token for authentication" + ) + parser.add_argument( + "--port", type=int, default=8000, help="The port number to run the server on" + ) + + args = parser.parse_args() + endpoint = args.endpoint + token = args.token + port = args.port + + uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/plugins/ocr/server.py b/plugins/ocr/server.py index cabb39c..5f7b8c7 100644 --- a/plugins/ocr/server.py +++ b/plugins/ocr/server.py @@ -1,6 +1,7 @@ from rapidocr_onnxruntime import RapidOCR from PIL import Image import numpy as np +import logging from fastapi import FastAPI, Body, HTTPException import base64 import io @@ -10,6 +11,11 @@ from functools import partial from pydantic import BaseModel, Field from typing import List +# Configure logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + app = FastAPI() # Initialize OCR engine (will be updated later) @@ -31,6 +37,9 @@ def init_ocr(use_gpu): def convert_ocr_results(results): + if results is None: + return [] + converted = [] for result in results: item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]} @@ -94,6 +103,7 @@ async def predict_base64(image_base64: str = Body(..., embed=True)): return convert_to_python_type(ocr_result) except Exception as e: + logging.error(f"Error during OCR processing: {str(e)}") raise HTTPException(status_code=500, detail=str(e))