mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(ocr): support local paddleocr
This commit is contained in:
parent
d3b45ad197
commit
2b2d616775
@ -25,12 +25,14 @@ class OCRSettings(BaseModel):
|
||||
endpoint: str = "http://localhost:5555/predict"
|
||||
token: str = ""
|
||||
concurrency: int = 4
|
||||
use_local: bool = True
|
||||
use_gpu: bool = False
|
||||
|
||||
|
||||
class EmbeddingSettings(BaseModel):
|
||||
num_dim: int = 768
|
||||
ollama_endpoint: str = "http://localhost:11434"
|
||||
ollama_model: str = "jina/jina-embeddings-v2-base-en"
|
||||
ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
|
@ -1,10 +1,16 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
import httpx
|
||||
import json
|
||||
import base64
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
import yaml
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
||||
from memos.schemas import Entity, MetadataType
|
||||
@ -17,6 +23,10 @@ endpoint = None
|
||||
token = None
|
||||
concurrency = None
|
||||
semaphore = None
|
||||
use_local = False
|
||||
use_gpu = False
|
||||
ocr = None
|
||||
thread_pool = None
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -48,8 +58,39 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
|
||||
return response.json()
|
||||
|
||||
|
||||
def convert_ocr_results(results):
|
||||
if results is None:
|
||||
return []
|
||||
|
||||
converted = []
|
||||
for result in results:
|
||||
item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
|
||||
converted.append(item)
|
||||
return converted
|
||||
|
||||
|
||||
def predict_local(img_path):
|
||||
try:
|
||||
with Image.open(img_path) as img:
|
||||
img_array = np.array(img)
|
||||
results, _ = ocr(img_array)
|
||||
return convert_ocr_results(results)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image {img_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def async_predict_local(img_path):
|
||||
loop = asyncio.get_running_loop()
|
||||
results = await loop.run_in_executor(thread_pool, partial(predict_local, img_path))
|
||||
return results
|
||||
|
||||
|
||||
# Modify the predict function to use semaphore
|
||||
async def predict(img_path):
|
||||
if use_local:
|
||||
return await async_predict_local(img_path)
|
||||
|
||||
image_base64 = image2base64(img_path)
|
||||
if not image_base64:
|
||||
return None
|
||||
@ -121,16 +162,40 @@ async def ocr(entity: Entity, request: Request):
|
||||
|
||||
|
||||
def init_plugin(config):
|
||||
global endpoint, token, concurrency, semaphore
|
||||
global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool
|
||||
endpoint = config.endpoint
|
||||
token = config.token
|
||||
concurrency = config.concurrency
|
||||
use_local = config.use_local
|
||||
use_gpu = config.use_gpu
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
if use_local:
|
||||
config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml")
|
||||
|
||||
# Load and update the config file with absolute model paths
|
||||
with open(config_path, 'r') as f:
|
||||
ocr_config = yaml.safe_load(f)
|
||||
|
||||
model_dir = os.path.join(os.path.dirname(__file__), "models")
|
||||
ocr_config['Det']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Det']['model_path']))
|
||||
ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
|
||||
ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))
|
||||
|
||||
# Save the updated config to a temporary file
|
||||
temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
|
||||
with open(temp_config_path, 'w') as f:
|
||||
yaml.safe_dump(ocr_config, f)
|
||||
|
||||
ocr = RapidOCR(config_path=temp_config_path)
|
||||
thread_pool = ThreadPoolExecutor(max_workers=concurrency)
|
||||
|
||||
logger.info("OCR plugin initialized")
|
||||
logger.info(f"Endpoint: {endpoint}")
|
||||
logger.info(f"Token: {token}")
|
||||
logger.info(f"Concurrency: {concurrency}")
|
||||
logger.info(f"Use local: {use_local}")
|
||||
logger.info(f"Use GPU: {use_gpu}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -154,6 +219,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to run the server on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-local", action="store_true", help="Use local OCR processing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-gpu", action="store_true", help="Use GPU for local OCR processing"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user