From 228cdee66fc7d63ff4a1ae73a9c8048ddd219293 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Fri, 27 Sep 2024 23:29:52 +0800
Subject: [PATCH] feat: remove local inference for ocr and vlm

---
 memos/config.py           |  4 --
 memos/plugins/ocr/main.py | 74 +---------------------------------------
 memos/plugins/vlm/main.py | 84 +--------------------------------------------
 3 files changed, 3 insertions(+), 159 deletions(-)

diff --git a/memos/config.py b/memos/config.py
index b979852..f27b962 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -19,8 +19,6 @@ class VLMSettings(BaseModel):
     token: str = ""
     concurrency: int = 1
     force_jpeg: bool = False
-    use_local: bool = True
-    use_modelscope: bool = False
 
 
 class OCRSettings(BaseModel):
@@ -28,8 +26,6 @@ class OCRSettings(BaseModel):
     endpoint: str = "http://localhost:5555/predict"
     token: str = ""
     concurrency: int = 1
-    use_local: bool = True
-    use_gpu: bool = False
     force_jpeg: bool = False
 
 
diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py
index 910c77d..8716afc 100644
--- a/memos/plugins/ocr/main.py
+++ b/memos/plugins/ocr/main.py
@@ -6,11 +6,6 @@ import httpx
 import json
 import base64
 from PIL import Image
-import numpy as np
-from rapidocr_onnxruntime import RapidOCR
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-import yaml
 from fastapi import APIRouter, FastAPI, Request, HTTPException
 from memos.schemas import Entity, MetadataType
 
@@ -23,10 +18,6 @@ endpoint = None
 token = None
 concurrency = None
 semaphore = None
-use_local = False
-use_gpu = False
-ocr = None
-thread_pool = None
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -58,39 +49,7 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
     return response.json()
 
 
-def convert_ocr_results(results):
-    if results is None:
-        return []
-
-    converted = []
-    for result in results:
-        item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
-        converted.append(item)
-    return converted
-
-
-def predict_local(img_path):
-    try:
-        with Image.open(img_path) as img:
-            img_array = np.array(img)
-            results, _ = ocr(img_array)
-            return convert_ocr_results(results)
-    except Exception as e:
-        logger.error(f"Error processing image {img_path}: {str(e)}")
-        return None
-
-
-async def async_predict_local(img_path):
-    loop = asyncio.get_running_loop()
-    results = await loop.run_in_executor(thread_pool, partial(predict_local, img_path))
-    return results
-
-
-# Modify the predict function to use semaphore
 async def predict(img_path):
-    if use_local:
-        return await async_predict_local(img_path)
-
     image_base64 = image2base64(img_path)
     if not image_base64:
         return None
@@ -170,41 +129,16 @@ async def ocr(entity: Entity, request: Request):
 
 
 def init_plugin(config):
-    global endpoint, token, concurrency, semaphore, use_local, use_gpu, ocr, thread_pool
+    global endpoint, token, concurrency, semaphore
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
-    use_local = config.use_local
-    use_gpu = config.use_gpu
     semaphore = asyncio.Semaphore(concurrency)
-
-    if use_local:
-        config_path = os.path.join(os.path.dirname(__file__), "ppocr-gpu.yaml" if use_gpu else "ppocr.yaml")
-
-        # Load and update the config file with absolute model paths
-        with open(config_path, 'r') as f:
-            ocr_config = yaml.safe_load(f)
-
-        model_dir = os.path.join(os.path.dirname(__file__), "models")
-        ocr_config['Det']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Det']['model_path']))
-        ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
-        ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))
-
-        # Save the updated config to a temporary file with strings wrapped in double quotes
-        temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
-        with open(temp_config_path, 'w') as f:
-            yaml.safe_dump(ocr_config, f)
-
-        ocr = RapidOCR(config_path=temp_config_path)
-        thread_pool = ThreadPoolExecutor(max_workers=concurrency)
 
     logger.info("OCR plugin initialized")
     logger.info(f"Endpoint: {endpoint}")
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
-    logger.info(f"Use local: {use_local}")
-    logger.info(f"Use GPU: {use_gpu}")
-
 
 if __name__ == "__main__":
     import uvicorn
@@ -227,12 +161,6 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="The port number to run the server on"
     )
-    parser.add_argument(
-        "--use-local", action="store_true", help="Use local OCR processing"
-    )
-    parser.add_argument(
-        "--use-gpu", action="store_true", help="Use GPU for local OCR processing"
-    )
 
     args = parser.parse_args()
 
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index fc905a4..cd852aa 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -98,44 +98,7 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
 async def predict(
     endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
 ) -> Optional[str]:
-    if use_local:
-        return await predict_local(img_path)
-    else:
-        return await predict_remote(endpoint, modelname, img_path, token)
-
-
-async def predict_local(img_path: str) -> Optional[str]:
-    try:
-        image = Image.open(img_path)
-        task_prompt = ""
-        prompt = task_prompt + ""
-
-        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
-            florence_model.device, torch_dtype
-        )
-
-        generated_ids = florence_model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
-            max_new_tokens=1024,
-            do_sample=False,
-            num_beams=3,
-        )
-
-        generated_texts = florence_processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )
-
-        parsed_answer = florence_processor.post_process_generation(
-            generated_texts[0],
-            task=task_prompt,
-            image_size=(image.width, image.height),
-        )
-
-        return parsed_answer.get(task_prompt, "")
-    except Exception as e:
-        logger.error(f"Error processing image {img_path}: {str(e)}")
-        return None
+    return await predict_remote(endpoint, modelname, img_path, token)
 
 
 async def predict_remote(
@@ -262,55 +225,15 @@ async def vlm(entity: Entity, request: Request):
 
 
 def init_plugin(config):
-    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
    force_jpeg = config.force_jpeg
-    use_local = config.use_local
-    use_modelscope = config.use_modelscope
     semaphore = asyncio.Semaphore(concurrency)
 
-    if use_local:
-        # Detect the available device
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-
-        torch_dtype = (
-            torch.float32
-            if (
-                torch.cuda.is_available()
-                and torch.cuda.get_device_capability()[0] <= 6
-            )
-            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
-            else torch.float16
-        )
-        logger.info(f"Using device: {device}")
-
-        if use_modelscope:
-            model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
-            logger.info(f"Model downloaded from ModelScope to: {model_dir}")
-        else:
-            model_dir = "microsoft/Florence-2-base-ft"
-        logger.info(f"Using model: {model_dir}")
-
-        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
-            florence_model = AutoModelForCausalLM.from_pretrained(
-                model_dir,
-                torch_dtype=torch_dtype,
-                attn_implementation="sdpa",
-                trust_remote_code=True,
-            ).to(device)
-            florence_processor = AutoProcessor.from_pretrained(
-                model_dir, trust_remote_code=True
-            )
-        logger.info("Florence model and processor initialized")
-
     # Print the parameters
     logger.info("VLM plugin initialized")
     logger.info(f"Model Name: {modelname}")
@@ -318,8 +241,6 @@ def init_plugin(config):
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
-    logger.info(f"Use Local: {use_local}")
-    logger.info(f"Use ModelScope: {use_modelscope}")
 
 
 if __name__ == "__main__":
@@ -338,7 +259,6 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="Port to run the server on"
     )
-    parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
 
     args = parser.parse_args()
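
After this patch both plugins are remote-only: predict() always forwards the image to the configured HTTP endpoint, and the use_local / use_gpu / use_modelscope settings and CLI flags are gone. As a rough sketch of the resulting configuration surface, reconstructed only from the hunk context above (the real classes may define additional fields, e.g. VLMSettings.modelname and VLMSettings.endpoint, which init_plugin still reads but which fall outside the hunks shown):

    # Sketch of memos/config.py after the patch, assuming pydantic models as the
    # class definitions above suggest; only fields visible in the diff are listed.
    from pydantic import BaseModel


    class VLMSettings(BaseModel):
        token: str = ""
        concurrency: int = 1
        force_jpeg: bool = False


    class OCRSettings(BaseModel):
        endpoint: str = "http://localhost:5555/predict"
        token: str = ""
        concurrency: int = 1
        force_jpeg: bool = False

A deployment that previously relied on use_local therefore needs a running OCR/VLM service and an endpoint pointing at it, e.g. init_plugin(OCRSettings(endpoint="http://localhost:5555/predict")) for the OCR plugin (endpoint value shown only as an illustration; it is the existing default).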