mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-07 11:45:25 +00:00
feat(ocr): remove local support for gpu
This commit is contained in:
parent
56864baaa0
commit
ece78023f5
@ -29,7 +29,6 @@ class OCRSettings(BaseModel):
|
|||||||
token: str = ""
|
token: str = ""
|
||||||
concurrency: int = 1
|
concurrency: int = 1
|
||||||
use_local: bool = True
|
use_local: bool = True
|
||||||
use_gpu: bool = False
|
|
||||||
force_jpeg: bool = False
|
force_jpeg: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
from fastapi import APIRouter, Request, HTTPException
|
||||||
from memos.schemas import Entity, MetadataType
|
from memos.schemas import Entity, MetadataType
|
||||||
|
|
||||||
METADATA_FIELD_NAME = "ocr_result"
|
METADATA_FIELD_NAME = "ocr_result"
|
||||||
@ -24,7 +24,6 @@ token = None
|
|||||||
concurrency = None
|
concurrency = None
|
||||||
semaphore = None
|
semaphore = None
|
||||||
use_local = False
|
use_local = False
|
||||||
use_gpu = False
|
|
||||||
ocr = None
|
ocr = None
|
||||||
thread_pool = None
|
thread_pool = None
|
||||||
|
|
||||||
@ -170,16 +169,15 @@ async def ocr(entity: Entity, request: Request):
|
|||||||
|
|
||||||
|
|
||||||
def init_plugin(config):
|
def init_plugin(config):
|
||||||
global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool
|
global endpoint, token, concurrency, semaphore, use_local, ocr, thread_pool
|
||||||
endpoint = config.endpoint
|
endpoint = config.endpoint
|
||||||
token = config.token
|
token = config.token
|
||||||
concurrency = config.concurrency
|
concurrency = config.concurrency
|
||||||
use_local = config.use_local
|
use_local = config.use_local
|
||||||
use_gpu = config.use_gpu
|
|
||||||
semaphore = asyncio.Semaphore(concurrency)
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
if use_local:
|
if use_local:
|
||||||
config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml")
|
config_path = os.path.join(os.path.dirname(__file__), "ppocr.yaml")
|
||||||
|
|
||||||
# Load and update the config file with absolute model paths
|
# Load and update the config file with absolute model paths
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
@ -203,42 +201,4 @@ def init_plugin(config):
|
|||||||
logger.info(f"Token: {token}")
|
logger.info(f"Token: {token}")
|
||||||
logger.info(f"Concurrency: {concurrency}")
|
logger.info(f"Concurrency: {concurrency}")
|
||||||
logger.info(f"Use local: {use_local}")
|
logger.info(f"Use local: {use_local}")
|
||||||
logger.info(f"Use GPU: {use_gpu}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
import argparse
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="OCR Plugin")
|
|
||||||
parser.add_argument(
|
|
||||||
"--endpoint",
|
|
||||||
type=str,
|
|
||||||
default="http://localhost:8080",
|
|
||||||
help="The endpoint URL for the OCR service",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--token", type=str, default="", help="The token for authentication"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--concurrency", type=int, default=4, help="The concurrency level"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port", type=int, default=8000, help="The port number to run the server on"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-local", action="store_true", help="Use local OCR processing"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-gpu", action="store_true", help="Use GPU for local OCR processing"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
init_plugin(args)
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user