feat(ocr): update main

This commit is contained in:
arkohut 2024-08-17 00:21:38 +08:00
parent 44f015d3fe
commit e19ebe6054
2 changed files with 85 additions and 15 deletions

View File

@ -1,26 +1,66 @@
import asyncio
import logging
from typing import Optional
import httpx import httpx
import json import json
import base64
import io
import os
from PIL import Image
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, Request, HTTPException
from memos.schemas import Entity, MetadataType from memos.schemas import Entity, MetadataType
from rapidocr_onnxruntime import RapidOCR, VisRes
engine = RapidOCR()
vis = VisRes()
METADATA_FIELD_NAME = "ocr_result" METADATA_FIELD_NAME = "ocr_result"
PLUGIN_NAME = "ocr" PLUGIN_NAME = "ocr"
app = FastAPI()
def predict(img_path): endpoint = None
result, elapse = engine(img_path) token = None
if result is None: semaphore = asyncio.Semaphore(4)
return None, None
return [ # Configure logger
{"dt_boxes": item[0], "rec_txt": item[1], "score": item[2]} for item in result logging.basicConfig(level=logging.INFO)
], elapse logger = logging.getLogger(__name__)
def image2base64(img_path):
    """Return the file at *img_path* as a base64-encoded UTF-8 string.

    Before encoding, the file is opened with Pillow and converted to RGB,
    purely as a check that the image is not broken. Any exception raised
    while checking, reading or encoding is logged and reported as ``None``
    instead of propagating.
    """
    try:
        # The converted image is discarded; the call only exists so that a
        # broken image raises here, before anything is encoded.
        with Image.open(img_path) as probe:
            probe.convert("RGB")
        # Encode the raw file bytes, not the converted image.
        with open(img_path, "rb") as fh:
            payload = base64.b64encode(fh.read())
        return payload.decode("utf-8")
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None
async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = None):
    """POST a base64-encoded image to the OCR endpoint and return its JSON reply.

    Args:
        endpoint: URL the request is sent to.
        client: Object exposing an awaitable ``post`` method; the caller in
            this file passes an ``httpx.AsyncClient``.
        image_base64: Image payload, sent as the ``image_base64`` JSON field.
        headers: Optional extra request headers (e.g. an Authorization header).

    Returns:
        The decoded JSON body when the response status is 200, otherwise
        ``None``.
    """
    # The module-level semaphore is used to control concurrency of requests.
    async with semaphore:
        response = await client.post(
            f"{endpoint}",
            json={"image_base64": image_base64},
            timeout=60,
            headers=headers,
        )
        if response.status_code != 200:
            # Callers only ever see None for a failed request, so record the
            # status here; otherwise a bad endpoint or token is invisible.
            logger.warning(
                "OCR request to %s failed with HTTP status %s",
                endpoint,
                response.status_code,
            )
            return None
        return response.json()
async def predict(img_path):
    """Run remote OCR on the image at *img_path*.

    The image is base64-encoded with ``image2base64`` and sent to the
    module-level ``endpoint`` through ``fetch``. When the module-level
    ``token`` is set, it is sent as a Bearer ``Authorization`` header.

    Returns:
        Whatever ``fetch`` returns, or ``None`` when the image could not be
        encoded.
    """
    encoded_image = image2base64(img_path)
    if not encoded_image:
        return None

    # Only attach credentials when a token has been configured.
    auth_headers = {"Authorization": f"Bearer {token}"} if token else {}

    async with httpx.AsyncClient() as client:
        return await fetch(endpoint, client, encoded_image, headers=auth_headers)
app = FastAPI() app = FastAPI()
@ -43,7 +83,7 @@ async def ocr(entity: Entity, request: Request):
patch_url = f"{location_url}/metadata" patch_url = f"{location_url}/metadata"
ocr_result, _ = predict(entity.filepath) ocr_result = await predict(entity.filepath)
print(ocr_result) print(ocr_result)
if ocr_result is None or not ocr_result: if ocr_result is None or not ocr_result:
@ -85,5 +125,25 @@ async def ocr(entity: Entity, request: Request):
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
import argparse
uvicorn.run(app, host="0.0.0.0", port=8000) parser = argparse.ArgumentParser(description="OCR Plugin")
parser.add_argument(
"--endpoint",
type=str,
required=True,
help="The endpoint URL for the OCR service",
)
parser.add_argument(
"--token", type=str, required=False, help="The token for authentication"
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to run the server on"
)
args = parser.parse_args()
endpoint = args.endpoint
token = args.token
port = args.port
uvicorn.run(app, host="0.0.0.0", port=port)

View File

@ -1,6 +1,7 @@
from rapidocr_onnxruntime import RapidOCR from rapidocr_onnxruntime import RapidOCR
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import logging
from fastapi import FastAPI, Body, HTTPException from fastapi import FastAPI, Body, HTTPException
import base64 import base64
import io import io
@ -10,6 +11,11 @@ from functools import partial
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List from typing import List
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI() app = FastAPI()
# Initialize OCR engine (will be updated later) # Initialize OCR engine (will be updated later)
@ -31,6 +37,9 @@ def init_ocr(use_gpu):
def convert_ocr_results(results): def convert_ocr_results(results):
if results is None:
return []
converted = [] converted = []
for result in results: for result in results:
item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]} item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
@ -94,6 +103,7 @@ async def predict_base64(image_base64: str = Body(..., embed=True)):
return convert_to_python_type(ocr_result) return convert_to_python_type(ocr_result)
except Exception as e: except Exception as e:
logging.error(f"Error during OCR processing: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))