feat(ocr): remove local support for gpu

This commit is contained in:
arkohut 2024-10-04 15:08:43 +08:00
parent 56864baaa0
commit ece78023f5
2 changed files with 3 additions and 44 deletions

View File

@ -29,7 +29,6 @@ class OCRSettings(BaseModel):
token: str = "" token: str = ""
concurrency: int = 1 concurrency: int = 1
use_local: bool = True use_local: bool = True
use_gpu: bool = False
force_jpeg: bool = False force_jpeg: bool = False

View File

@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
import yaml import yaml
from fastapi import APIRouter, FastAPI, Request, HTTPException from fastapi import APIRouter, Request, HTTPException
from memos.schemas import Entity, MetadataType from memos.schemas import Entity, MetadataType
METADATA_FIELD_NAME = "ocr_result" METADATA_FIELD_NAME = "ocr_result"
@ -24,7 +24,6 @@ token = None
concurrency = None concurrency = None
semaphore = None semaphore = None
use_local = False use_local = False
use_gpu = False
ocr = None ocr = None
thread_pool = None thread_pool = None
@ -170,16 +169,15 @@ async def ocr(entity: Entity, request: Request):
def init_plugin(config): def init_plugin(config):
global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool global endpoint, token, concurrency, semaphore, use_local, ocr, thread_pool
endpoint = config.endpoint endpoint = config.endpoint
token = config.token token = config.token
concurrency = config.concurrency concurrency = config.concurrency
use_local = config.use_local use_local = config.use_local
use_gpu = config.use_gpu
semaphore = asyncio.Semaphore(concurrency) semaphore = asyncio.Semaphore(concurrency)
if use_local: if use_local:
config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml") config_path = os.path.join(os.path.dirname(__file__), "ppocr.yaml")
# Load and update the config file with absolute model paths # Load and update the config file with absolute model paths
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
@ -203,42 +201,4 @@ def init_plugin(config):
logger.info(f"Token: {token}") logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}") logger.info(f"Concurrency: {concurrency}")
logger.info(f"Use local: {use_local}") logger.info(f"Use local: {use_local}")
logger.info(f"Use GPU: {use_gpu}")
if __name__ == "__main__":
import uvicorn
import argparse
from fastapi import FastAPI
parser = argparse.ArgumentParser(description="OCR Plugin")
parser.add_argument(
"--endpoint",
type=str,
default="http://localhost:8080",
help="The endpoint URL for the OCR service",
)
parser.add_argument(
"--token", type=str, default="", help="The token for authentication"
)
parser.add_argument(
"--concurrency", type=int, default=4, help="The concurrency level"
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to run the server on"
)
parser.add_argument(
"--use-local", action="store_true", help="Use local OCR processing"
)
parser.add_argument(
"--use-gpu", action="store_true", help="Use GPU for local OCR processing"
)
args = parser.parse_args()
init_plugin(args)
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port)