mirror of https://github.com/tcsenpai/pensieve.git (synced 2025-06-07 03:35:24 +00:00)
Revert "feat: remove local inference for ocr and vlm"
This reverts commit 228cdee66fc7d63ff4a1ae73a9c8048ddd219293.
parent 8af5fbacd1
commit 56864baaa0
@@ -19,6 +19,8 @@ class VLMSettings(BaseModel):
     token: str = ""
     concurrency: int = 1
     force_jpeg: bool = False
+    use_local: bool = True
+    use_modelscope: bool = False
 
 
 class OCRSettings(BaseModel):

@@ -26,6 +28,8 @@ class OCRSettings(BaseModel):
     endpoint: str = "http://localhost:5555/predict"
     token: str = ""
     concurrency: int = 1
+    use_local: bool = True
+    use_gpu: bool = False
     force_jpeg: bool = False
 
 
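For context, a minimal sketch of the two settings models after this revert, assuming the surrounding pydantic definitions (fields outside these hunks, such as the VLM model name and endpoint, are omitted); local inference becomes the default again for both plugins:

from pydantic import BaseModel

class VLMSettings(BaseModel):
    # Only the fields visible in the hunk above; other fields omitted.
    token: str = ""
    concurrency: int = 1
    force_jpeg: bool = False
    use_local: bool = True          # restored by this revert
    use_modelscope: bool = False    # restored by this revert

class OCRSettings(BaseModel):
    endpoint: str = "http://localhost:5555/predict"
    token: str = ""
    concurrency: int = 1
    use_local: bool = True          # restored by this revert
    use_gpu: bool = False           # restored by this revert
    force_jpeg: bool = False

# With no overrides, both plugins default to local inference:
assert OCRSettings().use_local and VLMSettings().use_local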
@@ -6,6 +6,11 @@ import httpx
 import json
 import base64
 from PIL import Image
+import numpy as np
+from rapidocr_onnxruntime import RapidOCR
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+import yaml
 
 from fastapi import APIRouter, FastAPI, Request, HTTPException
 from memos.schemas import Entity, MetadataType

@@ -18,6 +23,10 @@ endpoint = None
 token = None
 concurrency = None
 semaphore = None
+use_local = False
+use_gpu = False
+ocr = None
+thread_pool = None
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -49,7 +58,39 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
     return response.json()
 
 
+def convert_ocr_results(results):
+    if results is None:
+        return []
+
+    converted = []
+    for result in results:
+        item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
+        converted.append(item)
+    return converted
+
+
+def predict_local(img_path):
+    try:
+        with Image.open(img_path) as img:
+            img_array = np.array(img)
+        results, _ = ocr(img_array)
+        return convert_ocr_results(results)
+    except Exception as e:
+        logger.error(f"Error processing image {img_path}: {str(e)}")
+        return None
+
+
+async def async_predict_local(img_path):
+    loop = asyncio.get_running_loop()
+    results = await loop.run_in_executor(thread_pool, partial(predict_local, img_path))
+    return results
+
+
+# Modify the predict function to use semaphore
 async def predict(img_path):
+    if use_local:
+        return await async_predict_local(img_path)
+
     image_base64 = image2base64(img_path)
     if not image_base64:
         return None
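Given the convert_ocr_results defined in the hunk above, a small illustration (sample values invented) of how it reshapes RapidOCR's [box, text, score] triples into dicts:

# Invented sample mimicking RapidOCR output: one detected text line as
# [quad box, recognized text, confidence score].
sample = [
    [[[10.0, 10.0], [120.0, 10.0], [120.0, 42.0], [10.0, 42.0]], "Hello world", 0.97],
]

print(convert_ocr_results(sample))
# [{'dt_boxes': [[10.0, 10.0], [120.0, 10.0], [120.0, 42.0], [10.0, 42.0]],
#   'rec_txt': 'Hello world', 'score': 0.97}]

print(convert_ocr_results(None))  # -> []  (nothing detected)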
@@ -129,16 +170,41 @@ async def ocr(entity: Entity, request: Request):
 
 
 def init_plugin(config):
-    global endpoint, token, concurrency, semaphore
+    global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
+    use_local = config.use_local
+    use_gpu = config.use_gpu
     semaphore = asyncio.Semaphore(concurrency)
+
+    if use_local:
+        config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml")
+
+        # Load and update the config file with absolute model paths
+        with open(config_path, 'r') as f:
+            ocr_config = yaml.safe_load(f)
+
+        model_dir = os.path.join(os.path.dirname(__file__), "models")
+        ocr_config['Det']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Det']['model_path']))
+        ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
+        ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))
+
+        # Save the updated config to a temporary file with strings wrapped in double quotes
+        temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
+        with open(temp_config_path, 'w') as f:
+            yaml.safe_dump(ocr_config, f)
+
+        ocr = RapidOCR(config_path=temp_config_path)
+        thread_pool = ThreadPoolExecutor(max_workers=concurrency)
+
     logger.info("OCR plugin initialized")
     logger.info(f"Endpoint: {endpoint}")
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
+    logger.info(f"Use local: {use_local}")
+    logger.info(f"Use GPU: {use_gpu}")
 
 
 if __name__ == "__main__":
     import uvicorn
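The use_local branch above rewrites the bundled RapidOCR config so every model_path points into the plugin's models directory before the file is handed to RapidOCR. A stand-alone sketch of just that step, assuming the same Det/Cls/Rec layout (the function name is hypothetical):

import os
import yaml

def rewrite_model_paths(config_path: str, model_dir: str, out_path: str) -> None:
    """Point every model_path in a RapidOCR config at model_dir (absolute paths)."""
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)
    for section in ("Det", "Cls", "Rec"):
        cfg[section]["model_path"] = os.path.join(
            model_dir, os.path.basename(cfg[section]["model_path"])
        )
    with open(out_path, "w") as f:
        yaml.safe_dump(cfg, f)

# Hypothetical usage, mirroring init_plugin:
# rewrite_model_paths("ppocr.yaml", "./models", "temp_ppocr.yaml")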
@@ -161,6 +227,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="The port number to run the server on"
     )
+    parser.add_argument(
+        "--use-local", action="store_true", help="Use local OCR processing"
+    )
+    parser.add_argument(
+        "--use-gpu", action="store_true", help="Use GPU for local OCR processing"
+    )
 
     args = parser.parse_args()
 
@@ -83,7 +83,44 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
 async def predict(
     endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
 ) -> Optional[str]:
-    return await predict_remote(endpoint, modelname, img_path, token)
+    if use_local:
+        return await predict_local(img_path)
+    else:
+        return await predict_remote(endpoint, modelname, img_path, token)
+
+
+async def predict_local(img_path: str) -> Optional[str]:
+    try:
+        image = Image.open(img_path)
+        task_prompt = "<MORE_DETAILED_CAPTION>"
+        prompt = task_prompt + ""
+
+        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
+            florence_model.device, torch_dtype
+        )
+
+        generated_ids = florence_model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=1024,
+            do_sample=False,
+            num_beams=3,
+        )
+
+        generated_texts = florence_processor.batch_decode(
+            generated_ids, skip_special_tokens=False
+        )
+
+        parsed_answer = florence_processor.post_process_generation(
+            generated_texts[0],
+            task=task_prompt,
+            image_size=(image.width, image.height),
+        )
+
+        return parsed_answer.get(task_prompt, "")
+    except Exception as e:
+        logger.error(f"Error processing image {img_path}: {str(e)}")
+        return None
 
 
 async def predict_remote(
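A hedged usage sketch of the dispatch above (the image path is invented; endpoint, modelname and token are the module globals set by init_plugin): callers keep the remote-style predict() signature, and the local Florence-2 path is taken internally whenever use_local is true.

import asyncio

async def caption_one(path: str) -> None:
    # Same call whether use_local is True (local Florence-2) or False (remote API).
    caption = await predict(endpoint, modelname, path, token=token)
    print(caption or "no caption produced")

# asyncio.run(caption_one("screenshot.png"))  # hypothetical image path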
@@ -210,15 +247,55 @@ async def vlm(entity: Entity, request: Request):
 
 
 def init_plugin(config):
-    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
 
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
     force_jpeg = config.force_jpeg
+    use_local = config.use_local
+    use_modelscope = config.use_modelscope
     semaphore = asyncio.Semaphore(concurrency)
+
+    if use_local:
+        # Detect the available device
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        torch_dtype = (
+            torch.float32
+            if (
+                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
+            )
+            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
+            else torch.float16
+        )
+        logger.info(f"Using device: {device}")
+
+        if use_modelscope:
+            model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
+            logger.info(f"Model downloaded from ModelScope to: {model_dir}")
+        else:
+            model_dir = "microsoft/Florence-2-base-ft"
+            logger.info(f"Using model: {model_dir}")
+
+        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+            florence_model = AutoModelForCausalLM.from_pretrained(
+                model_dir,
+                torch_dtype=torch_dtype,
+                attn_implementation="sdpa",
+                trust_remote_code=True,
+            ).to(device)
+        florence_processor = AutoProcessor.from_pretrained(
+            model_dir, trust_remote_code=True
+        )
+        logger.info("Florence model and processor initialized")
+
     # Print the parameters
     logger.info("VLM plugin initialized")
     logger.info(f"Model Name: {modelname}")
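The torch_dtype expression above packs the rule into one conditional; restated as a plain helper under the same assumptions (float32 for CUDA devices with compute capability 6.x or lower and for CPU-only hosts, float16 for newer CUDA GPUs and Apple MPS), purely for illustration:

import torch

def pick_florence_dtype() -> torch.dtype:
    # Older CUDA parts (compute capability <= 6) stay on fp32.
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        return torch.float32 if major <= 6 else torch.float16
    # Apple MPS handles fp16; plain CPU falls back to fp32.
    if torch.backends.mps.is_available():
        return torch.float16
    return torch.float32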
@@ -226,6 +303,8 @@ def init_plugin(config):
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
+    logger.info(f"Use Local: {use_local}")
+    logger.info(f"Use ModelScope: {use_modelscope}")
 
 
 if __name__ == "__main__":
@@ -244,6 +323,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="Port to run the server on"
     )
+    parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
 
     args = parser.parse_args()
 