mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(ocr): update main
This commit is contained in:
parent
44f015d3fe
commit
e19ebe6054
@ -1,26 +1,66 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
import httpx
|
||||
import json
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from memos.schemas import Entity, MetadataType
|
||||
|
||||
from rapidocr_onnxruntime import RapidOCR, VisRes
|
||||
|
||||
|
||||
engine = RapidOCR()
|
||||
vis = VisRes()
|
||||
|
||||
METADATA_FIELD_NAME = "ocr_result"
|
||||
PLUGIN_NAME = "ocr"
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
def predict(img_path):
|
||||
result, elapse = engine(img_path)
|
||||
if result is None:
|
||||
return None, None
|
||||
return [
|
||||
{"dt_boxes": item[0], "rec_txt": item[1], "score": item[2]} for item in result
|
||||
], elapse
|
||||
endpoint = None
|
||||
token = None
|
||||
semaphore = asyncio.Semaphore(4)
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def image2base64(img_path):
|
||||
try:
|
||||
with Image.open(img_path) as img:
|
||||
img.convert("RGB") # Check if image is not broken
|
||||
with open(img_path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
return encoded_string
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image {img_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = None):
|
||||
async with semaphore: # 使用信号量控制并发
|
||||
response = await client.post(
|
||||
f"{endpoint}",
|
||||
json={"image_base64": image_base64},
|
||||
timeout=60,
|
||||
headers=headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
return response.json()
|
||||
|
||||
|
||||
async def predict(img_path):
|
||||
image_base64 = image2base64(img_path)
|
||||
if not image_base64:
|
||||
return None
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
|
||||
return ocr_result
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
@ -43,7 +83,7 @@ async def ocr(entity: Entity, request: Request):
|
||||
|
||||
patch_url = f"{location_url}/metadata"
|
||||
|
||||
ocr_result, _ = predict(entity.filepath)
|
||||
ocr_result = await predict(entity.filepath)
|
||||
|
||||
print(ocr_result)
|
||||
if ocr_result is None or not ocr_result:
|
||||
@ -85,5 +125,25 @@ async def ocr(entity: Entity, request: Request):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
import argparse
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
parser = argparse.ArgumentParser(description="OCR Plugin")
|
||||
parser.add_argument(
|
||||
"--endpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The endpoint URL for the OCR service",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token", type=str, required=False, help="The token for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to run the server on"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
endpoint = args.endpoint
|
||||
token = args.token
|
||||
port = args.port
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import logging
|
||||
from fastapi import FastAPI, Body, HTTPException
|
||||
import base64
|
||||
import io
|
||||
@ -10,6 +11,11 @@ from functools import partial
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize OCR engine (will be updated later)
|
||||
@ -31,6 +37,9 @@ def init_ocr(use_gpu):
|
||||
|
||||
|
||||
def convert_ocr_results(results):
|
||||
if results is None:
|
||||
return []
|
||||
|
||||
converted = []
|
||||
for result in results:
|
||||
item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
|
||||
@ -94,6 +103,7 @@ async def predict_base64(image_base64: str = Body(..., embed=True)):
|
||||
return convert_to_python_type(ocr_result)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during OCR processing: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user