mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-09 04:35:26 +00:00
fix(ollama): make vlm plugin support ollama
This commit is contained in:
parent
67a5e10d3e
commit
b10b080800
@@ -7,6 +7,8 @@ from fastapi import APIRouter, FastAPI, Request, HTTPException
|
|||||||
from memos.schemas import Entity, MetadataType
|
from memos.schemas import Entity, MetadataType
|
||||||
import logging
|
import logging
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
PLUGIN_NAME = "vlm"
|
PLUGIN_NAME = "vlm"
|
||||||
PROMPT = "描述这张图片的内容"
|
PROMPT = "描述这张图片的内容"
|
||||||
@@ -44,19 +46,22 @@ def image2base64(img_path):
|
|||||||
|
|
||||||
async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
|
async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
|
||||||
async with semaphore: # 使用信号量控制并发
|
async with semaphore: # 使用信号量控制并发
|
||||||
response = await client.post(
|
try:
|
||||||
f"{endpoint}/v1/chat/completions",
|
response = await client.post(
|
||||||
json=request_data,
|
f"{endpoint}/v1/chat/completions",
|
||||||
timeout=60,
|
json=request_data,
|
||||||
headers=headers,
|
timeout=60,
|
||||||
)
|
headers=headers,
|
||||||
if response.status_code != 200:
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
choices = result.get("choices", [])
|
||||||
|
if choices and "message" in choices[0] and "content" in choices[0]["message"]:
|
||||||
|
return choices[0]["message"]["content"]
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception occurred: {str(e)}")
|
||||||
return None
|
return None
|
||||||
result = response.json()
|
|
||||||
choices = result.get("choices", [])
|
|
||||||
if choices and "message" in choices[0] and "content" in choices[0]["message"]:
|
|
||||||
return choices[0]["message"]["content"]
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
async def predict(
|
async def predict(
|
||||||
@@ -66,18 +71,30 @@ async def predict(
|
|||||||
if not img_base64:
|
if not img_base64:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Get the file extension
|
||||||
|
_, file_extension = os.path.splitext(img_path)
|
||||||
|
file_extension = file_extension.lower()[1:] # Remove the dot and convert to lowercase
|
||||||
|
|
||||||
|
# Determine the MIME type
|
||||||
|
mime_types = {
|
||||||
|
'png': 'image/png',
|
||||||
|
'jpg': 'image/jpeg',
|
||||||
|
'jpeg': 'image/jpeg',
|
||||||
|
'webp': 'image/webp'
|
||||||
|
}
|
||||||
|
mime_type = mime_types.get(file_extension, 'image/jpeg')
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"model": modelname,
|
"model": modelname,
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": PROMPT},
|
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": f"data:image/;base64,{img_base64}"},
|
"image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
|
||||||
"detail": "high",
|
|
||||||
},
|
},
|
||||||
|
{"type": "text", "text": PROMPT},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -99,6 +116,7 @@ async def read_root():
|
|||||||
return {"healthy": True}
|
return {"healthy": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", include_in_schema=False)
|
||||||
@router.post("/")
|
@router.post("/")
|
||||||
async def vlm(entity: Entity, request: Request):
|
async def vlm(entity: Entity, request: Request):
|
||||||
global modelname, endpoint, token
|
global modelname, endpoint, token
|
||||||
|
Loading…
x
Reference in New Issue
Block a user