diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index 3276ea0..30286ae 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -7,6 +7,8 @@ from fastapi import APIRouter, FastAPI, Request, HTTPException
 from memos.schemas import Entity, MetadataType
 import logging
 import uvicorn
+import os
+
 
 PLUGIN_NAME = "vlm"
 PROMPT = "描述这张图片的内容"
@@ -44,19 +46,22 @@ def image2base64(img_path):
 
 async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
     async with semaphore:  # 使用信号量控制并发
-        response = await client.post(
-            f"{endpoint}/v1/chat/completions",
-            json=request_data,
-            timeout=60,
-            headers=headers,
-        )
-        if response.status_code != 200:
+        try:
+            response = await client.post(
+                f"{endpoint}/v1/chat/completions",
+                json=request_data,
+                timeout=60,
+                headers=headers,
+            )
+            response.raise_for_status()
+            result = response.json()
+            choices = result.get("choices", [])
+            if choices and "message" in choices[0] and "content" in choices[0]["message"]:
+                return choices[0]["message"]["content"]
+            return ""
+        except Exception as e:
+            logger.error(f"Exception occurred: {str(e)}")
             return None
-        result = response.json()
-        choices = result.get("choices", [])
-        if choices and "message" in choices[0] and "content" in choices[0]["message"]:
-            return choices[0]["message"]["content"]
-        return ""
 
 
 async def predict(
@@ -65,6 +70,19 @@ async def predict(
     img_base64 = image2base64(img_path)
     if not img_base64:
         return None
+
+    # Get the file extension
+    _, file_extension = os.path.splitext(img_path)
+    file_extension = file_extension.lower()[1:]  # Remove the dot and convert to lowercase
+
+    # Determine the MIME type
+    mime_types = {
+        'png': 'image/png',
+        'jpg': 'image/jpeg',
+        'jpeg': 'image/jpeg',
+        'webp': 'image/webp'
+    }
+    mime_type = mime_types.get(file_extension, 'image/jpeg')
 
     request_data = {
         "model": modelname,
@@ -72,12 +90,11 @@
             {
                 "role": "user",
                 "content": [
-                    {"type": "text", "text": PROMPT},
                     {
                         "type": "image_url",
-                        "image_url": {"url": f"data:image/;base64,{img_base64}"},
-                        "detail": "high",
+                        "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                     },
+                    {"type": "text", "text": PROMPT},
                 ],
             }
         ],
@@ -99,6 +116,7 @@ async def read_root():
     return {"healthy": True}
 
 
+@router.post("", include_in_schema=False)
 @router.post("/")
 async def vlm(entity: Entity, request: Request):
     global modelname, endpoint, token