From ece78023f503ce7a67c9bafda13fca3fd8ca5381 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Fri, 4 Oct 2024 15:08:43 +0800 Subject: [PATCH] feat(ocr): remove local support for gpu --- memos/config.py | 1 - memos/plugins/ocr/main.py | 46 +++------------------------------------ 2 files changed, 3 insertions(+), 44 deletions(-) diff --git a/memos/config.py b/memos/config.py index b979852..17f8b6b 100644 --- a/memos/config.py +++ b/memos/config.py @@ -29,7 +29,6 @@ class OCRSettings(BaseModel): token: str = "" concurrency: int = 1 use_local: bool = True - use_gpu: bool = False force_jpeg: bool = False diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py index 910c77d..4d1b891 100644 --- a/memos/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial import yaml -from fastapi import APIRouter, FastAPI, Request, HTTPException +from fastapi import APIRouter, Request, HTTPException from memos.schemas import Entity, MetadataType METADATA_FIELD_NAME = "ocr_result" @@ -24,7 +24,6 @@ token = None concurrency = None semaphore = None use_local = False -use_gpu = False ocr = None thread_pool = None @@ -170,16 +169,15 @@ async def ocr(entity: Entity, request: Request): def init_plugin(config): - global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool + global endpoint, token, concurrency, semaphore, use_local, ocr, thread_pool endpoint = config.endpoint token = config.token concurrency = config.concurrency use_local = config.use_local - use_gpu = config.use_gpu semaphore = asyncio.Semaphore(concurrency) if use_local: - config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml") + config_path = os.path.join(os.path.dirname(__file__), "ppocr.yaml") # Load and update the config file with absolute model paths with open(config_path, 'r') as f: @@ -203,42 +201,4 @@ def init_plugin(config): logger.info(f"Token: {token}") logger.info(f"Concurrency: {concurrency}") logger.info(f"Use local: {use_local}") - logger.info(f"Use GPU: {use_gpu}") - -if __name__ == "__main__": - import uvicorn - import argparse - from fastapi import FastAPI - - parser = argparse.ArgumentParser(description="OCR Plugin") - parser.add_argument( - "--endpoint", - type=str, - default="http://localhost:8080", - help="The endpoint URL for the OCR service", - ) - parser.add_argument( - "--token", type=str, default="", help="The token for authentication" - ) - parser.add_argument( - "--concurrency", type=int, default=4, help="The concurrency level" - ) - parser.add_argument( - "--port", type=int, default=8000, help="The port number to run the server on" - ) - parser.add_argument( - "--use-local", action="store_true", help="Use local OCR processing" - ) - parser.add_argument( - "--use-gpu", action="store_true", help="Use GPU for local OCR processing" - ) - - args = parser.parse_args() - - init_plugin(args) - - app = FastAPI() - app.include_router(router) - - uvicorn.run(app, host="0.0.0.0", port=args.port)