diff --git a/memos/config.py b/memos/config.py
index 458ee8d..db5ffbc 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -19,6 +19,7 @@ class VLMSettings(BaseModel):
     token: str = ""
     concurrency: int = 4
     force_jpeg: bool = False
+    use_local: bool = True


 class OCRSettings(BaseModel):
diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py
index 399d23a..a0b207d 100644
--- a/memos/plugins/ocr/main.py
+++ b/memos/plugins/ocr/main.py
@@ -145,6 +145,7 @@ async def ocr(entity: Entity, request: Request):
                 }
             ]
         },
+        timeout=30,
     )

     # Check if the patch request was successful
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index f24ea70..ad636b6 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -9,6 +9,8 @@ import logging
 import uvicorn
 import os
 import io
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor

 PLUGIN_NAME = "vlm"
 PROMPT = "描述这张图片的内容"
@@ -21,6 +23,10 @@ token = None
 concurrency = None
 semaphore = None
 force_jpeg = None
+use_local = None
+florence_model = None
+florence_processor = None
+torch_dtype = None

 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -35,18 +41,18 @@ def image2base64(img_path):
         with Image.open(img_path) as img:
             if force_jpeg:
                 # Convert image to RGB mode (removes alpha channel if present)
-                img = img.convert('RGB')
+                img = img.convert("RGB")
                 # Save as JPEG in memory
                 buffer = io.BytesIO()
-                img.save(buffer, format='JPEG')
+                img.save(buffer, format="JPEG")
                 buffer.seek(0)
-                encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
+                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
             else:
                 # Use original format
                 buffer = io.BytesIO()
                 img.save(buffer, format=img.format)
                 buffer.seek(0)
-                encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
+                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
             return encoded_string
     except Exception as e:
         logger.error(f"Error processing image {img_path}: {str(e)}")
@@ -79,12 +85,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
 async def predict(
     endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
+) -> Optional[str]:
+    if use_local:
+        return await predict_local(img_path)
+    else:
+        return await predict_remote(endpoint, modelname, img_path, token)
+
+
+async def predict_local(img_path: str) -> Optional[str]:
+    try:
+        image = Image.open(img_path)
+        task_prompt = "<MORE_DETAILED_CAPTION>"
+        prompt = task_prompt + ""
+
+        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
+            florence_model.device, torch_dtype
+        )
+
+        generated_ids = florence_model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=1024,
+            do_sample=False,
+            num_beams=3,
+        )
+
+        generated_texts = florence_processor.batch_decode(
+            generated_ids, skip_special_tokens=False
+        )
+
+        parsed_answer = florence_processor.post_process_generation(
+            generated_texts[0],
+            task=task_prompt,
+            image_size=(image.width, image.height),
+        )
+
+        return parsed_answer.get(task_prompt, "")
+    except Exception as e:
+        logger.error(f"Error processing image {img_path}: {str(e)}")
+        return None
+
+
+async def predict_remote(
+    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
 ) -> Optional[str]:
     img_base64 = image2base64(img_path)
     if not img_base64:
         return None

-    mime_type = "image/jpeg" if force_jpeg else "image/jpeg"  # Default to JPEG if force_jpeg is True
+    mime_type = (
+        "image/jpeg" if force_jpeg else "image/jpeg"
+    )  # Default to JPEG if force_jpeg is True

     if not force_jpeg:
         # Only determine MIME type if not forcing JPEG

@@ -167,9 +218,9 @@ async def vlm(entity: Entity, request: Request):

     vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)

-    print(vlm_result)
+    logger.info(vlm_result)
     if not vlm_result:
-        print(f"No VLM result found for file: {entity.filepath}")
+        logger.info(f"No VLM result found for file: {entity.filepath}")
         return {metadata_field_name: "{}"}

     async with httpx.AsyncClient() as client:
@@ -199,14 +250,46 @@ async def vlm(entity: Entity, request: Request):


 def init_plugin(config):
-    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
+
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
     force_jpeg = config.force_jpeg
+    use_local = config.use_local
     semaphore = asyncio.Semaphore(concurrency)

+    if use_local:
+        # Detect the available device
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        torch_dtype = (
+            torch.float32
+            if (
+                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
+            )
+            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
+            else torch.float16
+        )
+        logger.info(f"Using device: {device}")
+
+        florence_model = AutoModelForCausalLM.from_pretrained(
+            "microsoft/Florence-2-base-ft",
+            torch_dtype=torch_dtype,
+            attn_implementation="sdpa",
+            trust_remote_code=True,
+        ).to(device)
+        florence_processor = AutoProcessor.from_pretrained(
+            "microsoft/Florence-2-base-ft", trust_remote_code=True
+        )
+        logger.info("Florence model and processor initialized")
+
     # Print the parameters
     logger.info("VLM plugin initialized")
     logger.info(f"Model Name: {modelname}")
@@ -214,6 +297,7 @@ def init_plugin(config):
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
+    logger.info(f"Use Local: {use_local}")


 if __name__ == "__main__":
diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py
index 7a60825..b936668 100644
--- a/memos_ml_backends/server.py
+++ b/memos_ml_backends/server.py
@@ -50,7 +50,7 @@ if use_florence_model:
         torch_dtype=torch_dtype,
         attn_implementation="sdpa",
         trust_remote_code=True,
-    ).to(device)
+    ).to(device, torch_dtype)
     florence_processor = AutoProcessor.from_pretrained(
         "microsoft/Florence-2-base-ft", trust_remote_code=True
     )
@@ -60,7 +60,7 @@ else:
         "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
         torch_dtype=torch_dtype,
         device_map="auto",
-    ).to(device)
+    ).to(device, torch_dtype)
     qwen2vl_processor = AutoProcessor.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
     )
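
Note (not part of the patch): the sketch below mirrors the local Florence-2 path added to memos/plugins/vlm/main.py and can be used to sanity-check the model setup outside the plugin. The model id, task token, device/dtype rule, and generate() arguments are copied from the diff above; the image path is a placeholder and error handling is omitted.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Same device/dtype selection as init_plugin(): float16 on modern CUDA or MPS,
# float32 on older CUDA GPUs (compute capability <= 6) or plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
dtype = (
    torch.float32
    if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
    or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
    else torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft",
    torch_dtype=dtype,
    attn_implementation="sdpa",
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)

image = Image.open("screenshot.png")  # placeholder path
task_prompt = "<MORE_DETAILED_CAPTION>"
inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(
    device, dtype
)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    do_sample=False,
    num_beams=3,
)
raw_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed = processor.post_process_generation(
    raw_text, task=task_prompt, image_size=(image.width, image.height)
)
print(parsed.get(task_prompt, ""))

Since use_local defaults to True in VLMSettings, the plugin now captions images in-process by default; setting it to False restores the previous remote-endpoint behavior via predict_remote().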