fix(ollama): make vlm plugin support ollama
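Wrap the chat-completions request in try/except with response.raise_for_status() so connection and HTTP errors are logged and fetch returns None instead of raising. Build the image data URL with a real MIME type derived from the file extension (png/jpg/jpeg/webp, falling back to image/jpeg), drop the "detail" field, send the image part before the text prompt, and register the route both with and without a trailing slash.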

arkohut 2024-08-25 16:49:39 +08:00
parent 67a5e10d3e
commit b10b080800


@@ -7,6 +7,8 @@ from fastapi import APIRouter, FastAPI, Request, HTTPException
 from memos.schemas import Entity, MetadataType
 import logging
 import uvicorn
+import os

 PLUGIN_NAME = "vlm"
 PROMPT = "描述这张图片的内容"  # "Describe the content of this image"
@@ -44,19 +46,22 @@ def image2base64(img_path):

 async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
     async with semaphore:  # use a semaphore to limit concurrency
-        response = await client.post(
-            f"{endpoint}/v1/chat/completions",
-            json=request_data,
-            timeout=60,
-            headers=headers,
-        )
-        if response.status_code != 200:
+        try:
+            response = await client.post(
+                f"{endpoint}/v1/chat/completions",
+                json=request_data,
+                timeout=60,
+                headers=headers,
+            )
+            response.raise_for_status()
+            result = response.json()
+            choices = result.get("choices", [])
+            if choices and "message" in choices[0] and "content" in choices[0]["message"]:
+                return choices[0]["message"]["content"]
+            return ""
+        except Exception as e:
+            logger.error(f"Exception occurred: {str(e)}")
             return None
-        result = response.json()
-        choices = result.get("choices", [])
-        if choices and "message" in choices[0] and "content" in choices[0]["message"]:
-            return choices[0]["message"]["content"]
-        return ""


 async def predict(
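For reference, the new error-handling shape of fetch can be exercised on its own. The sketch below is not part of the plugin: it assumes a local Ollama instance exposing the OpenAI-compatible /v1/chat/completions API at http://localhost:11434 and a placeholder vision model name, and it mirrors the try / raise_for_status / except flow introduced above.

# Standalone sketch (assumed endpoint and model, not the plugin's config):
# post to an OpenAI-compatible chat/completions endpoint, raise on HTTP
# errors, and return None on any failure, as the new fetch does.
import asyncio
from typing import Optional

import httpx

ENDPOINT = "http://localhost:11434"  # assumed local Ollama instance
MODEL = "minicpm-v"                  # placeholder vision-capable model

async def fetch_once(request_data: dict) -> Optional[str]:
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{ENDPOINT}/v1/chat/completions",
                json=request_data,
                timeout=60,
            )
            response.raise_for_status()
            choices = response.json().get("choices", [])
            if choices and "content" in choices[0].get("message", {}):
                return choices[0]["message"]["content"]
            return ""
    except Exception as e:
        print(f"Exception occurred: {e}")
        return None

if __name__ == "__main__":
    payload = {"model": MODEL, "messages": [{"role": "user", "content": "hello"}]}
    print(asyncio.run(fetch_once(payload)))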
@@ -65,6 +70,19 @@ async def predict(
     img_base64 = image2base64(img_path)
     if not img_base64:
         return None
+
+    # Get the file extension
+    _, file_extension = os.path.splitext(img_path)
+    file_extension = file_extension.lower()[1:]  # Remove the dot and convert to lowercase
+
+    # Determine the MIME type
+    mime_types = {
+        'png': 'image/png',
+        'jpg': 'image/jpeg',
+        'jpeg': 'image/jpeg',
+        'webp': 'image/webp'
+    }
+    mime_type = mime_types.get(file_extension, 'image/jpeg')

     request_data = {
         "model": modelname,
@@ -72,12 +90,11 @@
             {
                 "role": "user",
                 "content": [
-                    {"type": "text", "text": PROMPT},
                     {
                         "type": "image_url",
-                        "image_url": {"url": f"data:image/;base64,{img_base64}"},
-                        "detail": "high",
+                        "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                     },
+                    {"type": "text", "text": PROMPT},
                 ],
             }
         ],
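Taken together with the MIME detection above, the request body after this change looks roughly like the sketch below; the model name and the placeholder values are illustrative, not taken from the diff.

# Sketch of the resulting payload: the image part carries a proper data URL
# and comes first, the text prompt follows, and the old "detail" field is gone.
mime_type = "image/png"                      # placeholder
img_base64 = "<base64-encoded image bytes>"  # placeholder
PROMPT = "描述这张图片的内容"                  # "Describe the content of this image"

request_data = {
    "model": "minicpm-v",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                },
                {"type": "text", "text": PROMPT},
            ],
        }
    ],
}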
@@ -99,6 +116,7 @@ async def read_root():
     return {"healthy": True}


+@router.post("", include_in_schema=False)
 @router.post("/")
 async def vlm(entity: Entity, request: Request):
     global modelname, endpoint, token
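The double registration lets both POST /vlm and POST /vlm/ reach the handler directly rather than via a trailing-slash redirect, while keeping a single entry in the OpenAPI schema. A self-contained sketch of that FastAPI pattern; the /vlm prefix is assumed here for illustration only.

# Sketch of the trailing-slash pattern used above: the same handler is
# registered for "" and "/", with the duplicate hidden from the docs.
from fastapi import APIRouter, FastAPI

router = APIRouter(prefix="/vlm")  # assumed prefix, for illustration only

@router.post("", include_in_schema=False)  # matches POST /vlm
@router.post("/")                          # matches POST /vlm/
async def vlm_stub():
    return {"healthy": True}

app = FastAPI()
app.include_router(router)
# Both "POST /vlm" and "POST /vlm/" now reach vlm_stub without a 307 redirect.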