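"""VLM plugin for memos: sends image entities to an OpenAI-compatible
vision-language model endpoint and stores the returned description as
entity metadata."""
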
import asyncio
import base64
import logging
import os
from typing import Optional

import httpx
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request
from PIL import Image

from memos.schemas import Entity, MetadataType

PLUGIN_NAME = "vlm"
# Chinese for "Describe the content of this image"; adjust to suit your model.
PROMPT = "描述这张图片的内容"

router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})

# Plugin configuration, populated by init_plugin()
modelname = None
endpoint = None
token = None
concurrency = None
semaphore = None

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def image2base64(img_path):
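    """Verify the image at img_path with PIL and return its contents
    base64-encoded, or None if the file is not a valid image."""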
    try:
        # Attempt to open and verify the image
        with Image.open(img_path) as img:
            img.verify()  # Verify that it's a valid image file

        # If verification passes, encode the image
        with open(img_path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
        return encoded_string
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None


async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
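    """POST request_data to the endpoint's /v1/chat/completions route and return
    the first choice's message content ("" if absent, None on error)."""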
    async with semaphore:  # Use the semaphore to cap concurrent requests
        try:
            response = await client.post(
                f"{endpoint}/v1/chat/completions",
                json=request_data,
                timeout=60,
                headers=headers,
            )
            response.raise_for_status()
            result = response.json()
            choices = result.get("choices", [])
            if (
                choices
                and "message" in choices[0]
                and "content" in choices[0]["message"]
            ):
                return choices[0]["message"]["content"]
            return ""
        except Exception as e:
            logger.error(f"Exception occurred: {str(e)}")
            return None


async def predict(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
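    """Build an OpenAI-style chat request for the image at img_path and return
    the model's description, or None if the image could not be read."""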
    img_base64 = image2base64(img_path)
    if not img_base64:
        return None

    # Get the file extension
    _, file_extension = os.path.splitext(img_path)
    file_extension = file_extension.lower()[1:]  # Remove the dot, convert to lowercase

    # Determine the MIME type
    mime_types = {
        "png": "image/png",
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "webp": "image/webp",
    }
    mime_type = mime_types.get(file_extension, "image/jpeg")

    request_data = {
        "model": modelname,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                    },
                    {"type": "text", "text": PROMPT},
                ],
            }
        ],
        "stream": False,
        "max_tokens": 1024,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "top_p": 0.8,
    }

    async with httpx.AsyncClient() as client:
        headers = {}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return await fetch(endpoint, client, request_data, headers=headers)


@router.get("/")
async def read_root():
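    """Health-check endpoint."""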
    return {"healthy": True}


@router.post("", include_in_schema=False)
@router.post("/")
async def vlm(entity: Entity, request: Request):
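    """Webhook endpoint: generate a description for an image entity and PATCH it
    back to {Location}/metadata on the calling server."""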
    global modelname, endpoint, token

    metadata_field_name = f"{modelname.replace('-', '_')}_result"

    if entity.file_type_group != "image":
        return {metadata_field_name: ""}

    # Check if the metadata_field_name field is empty or null
    existing_metadata = entity.get_metadata_by_key(metadata_field_name)
    if (
        existing_metadata
        and existing_metadata.value
        and existing_metadata.value.strip()
    ):
        logger.info(
            f"Skipping processing for file: {entity.filepath} due to existing metadata"
        )
        # If the field is not empty, return without processing
        return {metadata_field_name: existing_metadata.value}

    # Check if the entity contains the tag "low_info"
    if any(tag.name == "low_info" for tag in entity.tags):
        # If the tag is present, return without processing
        logger.info(
            f"Skipping processing for file: {entity.filepath} due to 'low_info' tag"
        )
        return {metadata_field_name: ""}

    location_url = request.headers.get("Location")
    if not location_url:
        raise HTTPException(status_code=400, detail="Location header is missing")
    patch_url = f"{location_url}/metadata"

    vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
    logger.info(f"VLM result: {vlm_result}")
    if not vlm_result:
        logger.warning(f"No VLM result found for file: {entity.filepath}")
        return {metadata_field_name: "{}"}

    async with httpx.AsyncClient() as client:
        response = await client.patch(
            patch_url,
            json={
                "metadata_entries": [
                    {
                        "key": metadata_field_name,
                        "value": vlm_result,
                        "source": PLUGIN_NAME,
                        "data_type": MetadataType.TEXT_DATA.value,
                    }
                ]
            },
            timeout=30,
        )
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail="Failed to patch entity metadata",
            )

    return {
        metadata_field_name: vlm_result,
    }


def init_plugin(config):
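    """Populate the module-level configuration from config and create the
    semaphore that limits concurrent VLM requests."""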
    global modelname, endpoint, token, concurrency, semaphore
    modelname = config.modelname
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    semaphore = asyncio.Semaphore(concurrency)

    # Log the parameters (mask the token so the secret never lands in the logs)
    logger.info("VLM plugin initialized")
    logger.info(f"Model Name: {modelname}")
    logger.info(f"Endpoint: {endpoint}")
    logger.info(f"Token: {'***' if token else '(none)'}")
    logger.info(f"Concurrency: {concurrency}")


if __name__ == "__main__":
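    # Standalone mode; for example (script name, model, endpoint, and token are
    # placeholders, not defaults shipped with the plugin):
    #   python vlm_plugin.py --model-name minicpm-v --endpoint http://localhost:11434 --token xxx --port 8000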
    import argparse

    parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
    # dest="modelname" keeps the attribute name aligned with what init_plugin() expects
    parser.add_argument(
        "--model-name",
        dest="modelname",
        type=str,
        default="your_model_name",
        help="Model name",
    )
    parser.add_argument(
        "--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
    )
    parser.add_argument("--token", type=str, default="your_token", help="Access token")
    parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    args = parser.parse_args()
    init_plugin(args)

    print(f"Model Name: {args.modelname}")
    print(f"Endpoint: {args.endpoint}")
    print(f"Token: {'***' if args.token else '(none)'}")
    print(f"Concurrency: {args.concurrency}")
    print(f"Port: {args.port}")

    app = FastAPI()
    app.include_router(router)
    uvicorn.run(app, host="0.0.0.0", port=args.port)