Revert "feat: remove local inference for ocr and vlm"

This reverts commit 228cdee66fc7d63ff4a1ae73a9c8048ddd219293.
arkohut 2024-10-04 14:48:45 +08:00
parent 8af5fbacd1
commit 56864baaa0
3 changed files with 159 additions and 3 deletions

View File

@@ -19,6 +19,8 @@ class VLMSettings(BaseModel):
     token: str = ""
     concurrency: int = 1
     force_jpeg: bool = False
+    use_local: bool = True
+    use_modelscope: bool = False


 class OCRSettings(BaseModel):
@@ -26,6 +28,8 @@ class OCRSettings(BaseModel):
     endpoint: str = "http://localhost:5555/predict"
     token: str = ""
     concurrency: int = 1
+    use_local: bool = True
+    use_gpu: bool = False
     force_jpeg: bool = False
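Taken together, these defaults switch both plugins back to local inference out of the box. A minimal sketch of the settings as restored here (only the fields visible in this diff are shown; the real classes define more, such as the VLM model name and endpoint):

```python
from pydantic import BaseModel


class VLMSettings(BaseModel):
    # Fields shown in this diff; other fields omitted.
    token: str = ""
    concurrency: int = 1
    force_jpeg: bool = False
    use_local: bool = True        # run Florence-2 in-process instead of calling a remote endpoint
    use_modelscope: bool = False  # fetch the model from ModelScope rather than Hugging Face


class OCRSettings(BaseModel):
    endpoint: str = "http://localhost:5555/predict"
    token: str = ""
    concurrency: int = 1
    use_local: bool = True   # run RapidOCR in-process
    use_gpu: bool = False    # use the GPU ONNX config for local OCR
    force_jpeg: bool = False


ocr_settings = OCRSettings(use_gpu=True)
vlm_settings = VLMSettings(use_modelscope=True)
```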

View File

@@ -6,6 +6,11 @@ import httpx
 import json
 import base64
 from PIL import Image
+import numpy as np
+from rapidocr_onnxruntime import RapidOCR
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+import yaml
 from fastapi import APIRouter, FastAPI, Request, HTTPException

 from memos.schemas import Entity, MetadataType
@@ -18,6 +23,10 @@ endpoint = None
 token = None
 concurrency = None
 semaphore = None
+use_local = False
+use_gpu = False
+ocr = None
+thread_pool = None

 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -49,7 +58,39 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
     return response.json()


+def convert_ocr_results(results):
+    if results is None:
+        return []
+
+    converted = []
+    for result in results:
+        item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
+        converted.append(item)
+    return converted
+
+
+def predict_local(img_path):
+    try:
+        with Image.open(img_path) as img:
+            img_array = np.array(img)
+            results, _ = ocr(img_array)
+            return convert_ocr_results(results)
+    except Exception as e:
+        logger.error(f"Error processing image {img_path}: {str(e)}")
+        return None
+
+
+async def async_predict_local(img_path):
+    loop = asyncio.get_running_loop()
+    results = await loop.run_in_executor(thread_pool, partial(predict_local, img_path))
+    return results
+
+
+# Modify the predict function to use semaphore
 async def predict(img_path):
+    if use_local:
+        return await async_predict_local(img_path)
+
     image_base64 = image2base64(img_path)
     if not image_base64:
         return None
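The convert_ocr_results helper above flattens RapidOCR's list-of-triples output into dictionaries. A minimal standalone sketch of the shapes involved (the coordinates, text, and scores are made up):

```python
# RapidOCR's __call__ returns (results, elapse); each result is
# [box, text, score], where box is four [x, y] corner points.
raw_results = [
    [[[10.0, 10.0], [120.0, 10.0], [120.0, 40.0], [10.0, 40.0]], "hello", 0.98],
    [[[10.0, 50.0], [90.0, 50.0], [90.0, 80.0], [10.0, 80.0]], "world", 0.95],
]

converted = [
    {"dt_boxes": box, "rec_txt": text, "score": score}
    for box, text, score in raw_results
]
print(converted)
# [{'dt_boxes': [[10.0, 10.0], ...], 'rec_txt': 'hello', 'score': 0.98}, ...]
```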
@@ -129,16 +170,41 @@ async def ocr(entity: Entity, request: Request):
 def init_plugin(config):
-    global endpoint, token, concurrency, semaphore
+    global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
+    use_local = config.use_local
+    use_gpu = config.use_gpu
     semaphore = asyncio.Semaphore(concurrency)

+    if use_local:
+        config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml")
+
+        # Load and update the config file with absolute model paths
+        with open(config_path, 'r') as f:
+            ocr_config = yaml.safe_load(f)
+
+        model_dir = os.path.join(os.path.dirname(__file__), "models")
+        ocr_config['Det']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Det']['model_path']))
+        ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
+        ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))
+
+        # Save the updated config to a temporary file with strings wrapped in double quotes
+        temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
+        with open(temp_config_path, 'w') as f:
+            yaml.safe_dump(ocr_config, f)
+
+        ocr = RapidOCR(config_path=temp_config_path)
+        thread_pool = ThreadPoolExecutor(max_workers=concurrency)
+
     logger.info("OCR plugin initialized")
     logger.info(f"Endpoint: {endpoint}")
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
+    logger.info(f"Use local: {use_local}")
+    logger.info(f"Use GPU: {use_gpu}")


 if __name__ == "__main__":
     import uvicorn
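init_plugin rewrites the model paths in the bundled RapidOCR YAML so they point at the plugin's own models directory before handing a temporary copy to RapidOCR. A sketch of the resulting structure; the file names below are hypothetical, only the Det/Cls/Rec model_path nesting comes from the code above:

```python
import os

plugin_dir = os.path.dirname(os.path.abspath(__file__))

# Hypothetical contents of temp_ppocr.yaml after the rewrite; the .onnx file
# names are placeholders, the key layout mirrors what init_plugin touches.
ocr_config = {
    "Det": {"model_path": os.path.join(plugin_dir, "models", "det_model.onnx")},
    "Cls": {"model_path": os.path.join(plugin_dir, "models", "cls_model.onnx")},
    "Rec": {"model_path": os.path.join(plugin_dir, "models", "rec_model.onnx")},
}
```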
@@ -161,6 +227,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="The port number to run the server on"
     )
+    parser.add_argument(
+        "--use-local", action="store_true", help="Use local OCR processing"
+    )
+    parser.add_argument(
+        "--use-gpu", action="store_true", help="Use GPU for local OCR processing"
+    )

     args = parser.parse_args()
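With --use-local and --use-gpu in place, the server can be started without any remote OCR endpoint. A hedged sketch of how the parsed arguments might be fed into init_plugin; the actual __main__ wiring is not shown in this diff, so the attribute names below simply mirror what init_plugin reads:

```python
import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--use-local", action="store_true")
parser.add_argument("--use-gpu", action="store_true")
args = parser.parse_args()

# Assemble the attributes init_plugin reads; values other than the flags are illustrative.
config = SimpleNamespace(
    endpoint="http://localhost:5555/predict",
    token="",
    concurrency=1,
    use_local=args.use_local,
    use_gpu=args.use_gpu,
)
# init_plugin(config)  # then mount the router / run uvicorn on args.port
```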

View File

@@ -83,7 +83,44 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
 async def predict(
     endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
 ) -> Optional[str]:
-    return await predict_remote(endpoint, modelname, img_path, token)
+    if use_local:
+        return await predict_local(img_path)
+    else:
+        return await predict_remote(endpoint, modelname, img_path, token)
+
+
+async def predict_local(img_path: str) -> Optional[str]:
+    try:
+        image = Image.open(img_path)
+        task_prompt = "<MORE_DETAILED_CAPTION>"
+        prompt = task_prompt + ""
+
+        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
+            florence_model.device, torch_dtype
+        )
+
+        generated_ids = florence_model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=1024,
+            do_sample=False,
+            num_beams=3,
+        )
+
+        generated_texts = florence_processor.batch_decode(
+            generated_ids, skip_special_tokens=False
+        )
+
+        parsed_answer = florence_processor.post_process_generation(
+            generated_texts[0],
+            task=task_prompt,
+            image_size=(image.width, image.height),
+        )
+
+        return parsed_answer.get(task_prompt, "")
+    except Exception as e:
+        logger.error(f"Error processing image {img_path}: {str(e)}")
+        return None


 async def predict_remote(
@@ -210,15 +247,55 @@ async def vlm(entity: Entity, request: Request):
 def init_plugin(config):
-    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
     force_jpeg = config.force_jpeg
+    use_local = config.use_local
+    use_modelscope = config.use_modelscope
     semaphore = asyncio.Semaphore(concurrency)

+    if use_local:
+        # Detect the available device
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        torch_dtype = (
+            torch.float32
+            if (
+                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
+            )
+            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
+            else torch.float16
+        )
+        logger.info(f"Using device: {device}")
+
+        if use_modelscope:
+            model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
+            logger.info(f"Model downloaded from ModelScope to: {model_dir}")
+        else:
+            model_dir = "microsoft/Florence-2-base-ft"
+            logger.info(f"Using model: {model_dir}")
+
+        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+            florence_model = AutoModelForCausalLM.from_pretrained(
+                model_dir,
+                torch_dtype=torch_dtype,
+                attn_implementation="sdpa",
+                trust_remote_code=True,
+            ).to(device)
+            florence_processor = AutoProcessor.from_pretrained(
+                model_dir, trust_remote_code=True
+            )
+        logger.info("Florence model and processor initialized")
+
     # Print the parameters
     logger.info("VLM plugin initialized")
     logger.info(f"Model Name: {modelname}")
@@ -226,6 +303,8 @@ def init_plugin(config):
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
+    logger.info(f"Use Local: {use_local}")
+    logger.info(f"Use ModelScope: {use_modelscope}")


 if __name__ == "__main__":
@@ -244,6 +323,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="Port to run the server on"
     )
+    parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")

     args = parser.parse_args()
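Once init_plugin has run with use_local enabled, predict() takes the Florence-2 path and never touches the remote endpoint. A minimal usage sketch; the arguments are placeholders and it assumes execution inside the plugin module after initialization:

```python
import asyncio

async def demo():
    # predict() is the coroutine patched in this diff; endpoint/modelname are
    # ignored on the local path but still part of the signature.
    caption = await predict(
        endpoint="http://localhost:8000/predict",  # placeholder
        modelname="florence-2-base-ft",            # placeholder
        img_path="/path/to/screenshot.png",        # placeholder
        token=None,
    )
    print(caption)

# asyncio.run(demo())  # run inside the plugin module once init_plugin(config) has been called
```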