diff --git a/memos/config.py b/memos/config.py index 4257d33..5328a20 100644 --- a/memos/config.py +++ b/memos/config.py @@ -25,12 +25,14 @@ class OCRSettings(BaseModel): endpoint: str = "http://localhost:5555/predict" token: str = "" concurrency: int = 4 + use_local: bool = True + use_gpu: bool = False class EmbeddingSettings(BaseModel): num_dim: int = 768 ollama_endpoint: str = "http://localhost:11434" - ollama_model: str = "jina/jina-embeddings-v2-base-en" + ollama_model: str = "nextfire/paraphrase-multilingual-minilm" class Settings(BaseSettings): diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py index 674f4e8..399d23a 100644 --- a/memos/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -1,10 +1,16 @@ import asyncio import logging +import os from typing import Optional import httpx import json import base64 from PIL import Image +import numpy as np +from rapidocr_onnxruntime import RapidOCR +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import yaml from fastapi import APIRouter, FastAPI, Request, HTTPException from memos.schemas import Entity, MetadataType @@ -17,6 +23,10 @@ endpoint = None token = None concurrency = None semaphore = None +use_local = False +use_gpu = False +ocr = None +thread_pool = None # Configure logger logging.basicConfig(level=logging.INFO) @@ -48,8 +58,39 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N return response.json() +def convert_ocr_results(results): + if results is None: + return [] + + converted = [] + for result in results: + item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]} + converted.append(item) + return converted + + +def predict_local(img_path): + try: + with Image.open(img_path) as img: + img_array = np.array(img) + results, _ = ocr(img_array) + return convert_ocr_results(results) + except Exception as e: + logger.error(f"Error processing image {img_path}: {str(e)}") + return None + + +async def async_predict_local(img_path): + loop = asyncio.get_running_loop() + results = await loop.run_in_executor(thread_pool, partial(predict_local, img_path)) + return results + + # Modify the predict function to use semaphore async def predict(img_path): + if use_local: + return await async_predict_local(img_path) + image_base64 = image2base64(img_path) if not image_base64: return None @@ -121,16 +162,40 @@ async def ocr(entity: Entity, request: Request): def init_plugin(config): - global endpoint, token, concurrency, semaphore + global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool endpoint = config.endpoint token = config.token concurrency = config.concurrency + use_local = config.use_local + use_gpu = config.use_gpu semaphore = asyncio.Semaphore(concurrency) + + if use_local: + config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml") + + # Load and update the config file with absolute model paths + with open(config_path, 'r') as f: + ocr_config = yaml.safe_load(f) + + model_dir = os.path.join(os.path.dirname(__file__), "models") + ocr_config['Det']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Det']['model_path'])) + ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path'])) + ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path'])) + + # Save the updated config to a temporary file + temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml") + with open(temp_config_path, 'w') as f: + yaml.safe_dump(ocr_config, f) + + ocr = RapidOCR(config_path=temp_config_path) + thread_pool = ThreadPoolExecutor(max_workers=concurrency) logger.info("OCR plugin initialized") logger.info(f"Endpoint: {endpoint}") logger.info(f"Token: {token}") logger.info(f"Concurrency: {concurrency}") + logger.info(f"Use local: {use_local}") + logger.info(f"Use GPU: {use_gpu}") if __name__ == "__main__": @@ -154,6 +219,12 @@ if __name__ == "__main__": parser.add_argument( "--port", type=int, default=8000, help="The port number to run the server on" ) + parser.add_argument( + "--use-local", action="store_true", help="Use local OCR processing" + ) + parser.add_argument( + "--use-gpu", action="store_true", help="Use GPU for local OCR processing" + ) args = parser.parse_args()