Mirror of https://github.com/tcsenpai/pensieve.git, synced 2025-06-07 03:35:24 +00:00
fix(ollama): make vlm plugin support ollama
parent 67a5e10d3e
commit b10b080800
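
In short: the plugin posts a base64-encoded image to {endpoint}/v1/chat/completions, and the old request body used an empty MIME subtype (data:image/;base64,...) and gave up on any non-200 response. This patch derives the MIME type from the image's file extension (falling back to image/jpeg), wraps the call in try/except with response.raise_for_status(), removes the misplaced "detail": "high" field (in the OpenAI chat schema, detail lives inside the image_url object), moves the text prompt after the image part, and registers the endpoint both with and without a trailing slash, which together let the plugin talk to Ollama's OpenAI-compatible API. For reference, PROMPT ("描述这张图片的内容") is Chinese for "Describe the content of this image", and the inline comment 使用信号量控制并发 reads "use a semaphore to limit concurrency".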
@@ -7,6 +7,8 @@ from fastapi import APIRouter, FastAPI, Request, HTTPException
 from memos.schemas import Entity, MetadataType
 import logging
 import uvicorn
+import os
+
 
 PLUGIN_NAME = "vlm"
 PROMPT = "描述这张图片的内容"
@@ -44,19 +46,22 @@ def image2base64(img_path):
 
 async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
     async with semaphore:  # 使用信号量控制并发
-        response = await client.post(
-            f"{endpoint}/v1/chat/completions",
-            json=request_data,
-            timeout=60,
-            headers=headers,
-        )
-        if response.status_code != 200:
+        try:
+            response = await client.post(
+                f"{endpoint}/v1/chat/completions",
+                json=request_data,
+                timeout=60,
+                headers=headers,
+            )
+            response.raise_for_status()
+            result = response.json()
+            choices = result.get("choices", [])
+            if choices and "message" in choices[0] and "content" in choices[0]["message"]:
+                return choices[0]["message"]["content"]
+            return ""
+        except Exception as e:
+            logger.error(f"Exception occurred: {str(e)}")
             return None
-        result = response.json()
-        choices = result.get("choices", [])
-        if choices and "message" in choices[0] and "content" in choices[0]["message"]:
-            return choices[0]["message"]["content"]
-        return ""
 
 
 async def predict(
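
The rewritten fetch() now returns the first choice's message content on success, "" when the reply carries no usable content, and None on any failure, logged rather than raised. A minimal standalone sketch of the same control flow against a local Ollama instance (the endpoint, model name, and concurrency limit are assumptions; in the plugin these come from its configuration and module-level semaphore/logger):

import asyncio
import httpx

semaphore = asyncio.Semaphore(4)  # stand-in for the plugin's module-level semaphore

async def fetch(endpoint, client, request_data, headers=None):
    async with semaphore:
        try:
            response = await client.post(
                f"{endpoint}/v1/chat/completions",
                json=request_data,
                timeout=60,
                headers=headers,
            )
            response.raise_for_status()
            choices = response.json().get("choices", [])
            if choices and "message" in choices[0] and "content" in choices[0]["message"]:
                return choices[0]["message"]["content"]
            return ""
        except Exception as e:
            print(f"Exception occurred: {e}")  # the plugin uses logger.error here
            return None

async def main():
    request_data = {
        "model": "llava",  # assumed: any vision-capable model pulled into Ollama
        "messages": [{"role": "user", "content": "Reply with one word."}],
    }
    async with httpx.AsyncClient() as client:
        # Ollama serves an OpenAI-compatible API on port 11434 by default
        print(await fetch("http://localhost:11434", client, request_data))

asyncio.run(main())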
@@ -65,6 +70,19 @@ async def predict(
     img_base64 = image2base64(img_path)
     if not img_base64:
         return None
 
+    # Get the file extension
+    _, file_extension = os.path.splitext(img_path)
+    file_extension = file_extension.lower()[1:]  # Remove the dot and convert to lowercase
+
+    # Determine the MIME type
+    mime_types = {
+        'png': 'image/png',
+        'jpg': 'image/jpeg',
+        'jpeg': 'image/jpeg',
+        'webp': 'image/webp'
+    }
+    mime_type = mime_types.get(file_extension, 'image/jpeg')
+
     request_data = {
         "model": modelname,
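
The mapping covers the image formats the plugin expects and falls back to image/jpeg for anything else. A quick standalone check of the same lookup (file names invented for illustration):

import os

mime_types = {'png': 'image/png', 'jpg': 'image/jpeg',
              'jpeg': 'image/jpeg', 'webp': 'image/webp'}

for path in ("shot.PNG", "photo.jpeg", "scan.tiff"):
    _, ext = os.path.splitext(path)
    print(path, "->", mime_types.get(ext.lower()[1:], 'image/jpeg'))

# shot.PNG   -> image/png    (extension is lowercased first)
# photo.jpeg -> image/jpeg
# scan.tiff  -> image/jpeg   (unknown extension, fallback)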
@@ -72,12 +90,11 @@ async def predict(
         {
             "role": "user",
             "content": [
-                {"type": "text", "text": PROMPT},
                 {
                     "type": "image_url",
-                    "image_url": {"url": f"data:image/;base64,{img_base64}"},
-                    "detail": "high",
+                    "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                 },
+                {"type": "text", "text": PROMPT},
             ],
         }
     ],
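
Three changes land in this hunk: the data URI gains a real MIME type (the old data:image/;base64,... had an empty subtype), the misplaced "detail": "high" is removed (in the OpenAI chat schema it belongs inside the image_url object, not beside it), and the text prompt now follows the image part. The resulting body has this shape (model name and base64 payload are illustrative):

request_data = {
    "model": "llava",  # assumed model name
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
                },
                {"type": "text", "text": "描述这张图片的内容"},  # PROMPT
            ],
        }
    ],
}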
@@ -99,6 +116,7 @@ async def read_root():
     return {"healthy": True}
 
 
+@router.post("", include_in_schema=False)
 @router.post("/")
 async def vlm(entity: Entity, request: Request):
     global modelname, endpoint, token
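
The extra decorator is about trailing-slash handling: with the router mounted under a prefix, @router.post("/") only matches POST <prefix>/, so a POST to the bare <prefix> would be redirected or rejected. Registering "" as well accepts both forms, and include_in_schema=False keeps the duplicate out of the OpenAPI docs. A minimal sketch, with an assumed prefix:

from fastapi import APIRouter, FastAPI

router = APIRouter(prefix="/vlm")  # assumed prefix, for illustration

@router.post("", include_in_schema=False)  # matches POST /vlm, hidden from docs
@router.post("/")                          # matches POST /vlm/
async def vlm():
    return {"healthy": True}

app = FastAPI()
app.include_router(router)  # both /vlm and /vlm/ now reach vlm()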