feat(plugin): add vlm plugin

This commit is contained in:
arkohut 2024-08-24 00:20:07 +08:00
parent 5f998db989
commit cead5d9755
6 changed files with 246 additions and 8 deletions

View File

@ -1,10 +1,20 @@
import os
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import BaseModel
class VLMSettings(BaseModel):
    """Configuration for the VLM (vision-language model) plugin."""

    # Whether the server mounts the VLM plugin router at startup.
    enabled: bool = False
    # Model identifier sent to the OpenAI-compatible chat-completions endpoint.
    modelname: str = "internvl-1.5"
    # Base URL of the VLM service.
    endpoint: str = "http://localhost:11434"
    # Optional bearer token; empty string means no Authorization header is sent.
    token: str = ""
    # Maximum number of concurrent requests to the VLM service.
    concurrency: int = 8
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="MEMOS_")
model_config = SettingsConfigDict(
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
yaml_file_encoding="utf-8"
)
base_dir: str = str(Path.home() / ".memos")
database_path: str = os.path.join(base_dir, "database.db")
@ -15,6 +25,8 @@ class Settings(BaseSettings):
typesense_connection_timeout_seconds: int = 2
typesense_collection_name: str = "entities"
# VLM plugin settings
vlm: VLMSettings = VLMSettings()
settings = Settings()
@ -24,7 +36,6 @@ os.makedirs(settings.base_dir, exist_ok=True)
# Global variable for Typesense collection name
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
# Function to get the database path from environment variable or default
def get_database_path():
return settings.database_path

View File

View File

202
memos/plugins/vlm/main.py Normal file
View File

@ -0,0 +1,202 @@
import base64
import httpx
from PIL import Image
import asyncio
from typing import Optional
from fastapi import APIRouter, FastAPI, Request, HTTPException
from memos.schemas import Entity, MetadataType
import logging
import uvicorn
# Plugin identity: used as the router tag, the URL prefix chosen by the server,
# and the "source" value written into entity metadata.
PLUGIN_NAME = "vlm"
# Prompt sent with every image ("describe the content of this image" in Chinese).
PROMPT = "描述这张图片的内容"

router = APIRouter(
    tags=[PLUGIN_NAME],
    responses={404: {"description": "Not found"}}
)

# Module-level plugin configuration; populated by init_plugin() before the
# router is used. All remain None until then.
modelname = None    # model identifier for chat-completion requests
endpoint = None     # base URL of the OpenAI-compatible VLM service
token = None        # optional bearer token
concurrency = None  # cap on in-flight VLM requests
semaphore = None    # asyncio.Semaphore built from `concurrency`

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def image2base64(img_path):
    """Validate the image at *img_path* and return its contents base64-encoded.

    Returns None (after logging the error) when the file is missing,
    unreadable, or not a valid image.
    """
    try:
        # Verify first so a corrupt file is rejected before encoding.
        with Image.open(img_path) as img:
            img.verify()
        with open(img_path, "rb") as fh:
            payload = fh.read()
        return base64.b64encode(payload).decode("utf-8")
    except Exception as exc:
        logger.error(f"Error processing image {img_path}: {str(exc)}")
        return None
async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
    """POST one chat-completion request to the VLM service.

    Returns the first choice's message content, "" when the response contains
    no usable choices, or None on a non-200 status. Concurrency across all
    callers is bounded by the module-level semaphore.
    """
    async with semaphore:  # use the semaphore to limit concurrency
        response = await client.post(
            f"{endpoint}/v1/chat/completions",
            json=request_data,
            timeout=60,
            headers=headers,
        )
        if response.status_code != 200:
            # Log the failure instead of swallowing it silently so callers
            # can see why no result was produced.
            logger.error(
                f"VLM request failed with status {response.status_code}: {response.text}"
            )
            return None
        result = response.json()
        choices = result.get("choices", [])
        if choices and "message" in choices[0] and "content" in choices[0]["message"]:
            return choices[0]["message"]["content"]
        return ""
async def predict(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    """Ask the VLM service to describe the image at *img_path*.

    Returns the model's text answer, "" when the service returned no usable
    choices, or None when the image is invalid or the request failed.
    """
    img_base64 = image2base64(img_path)
    if not img_base64:
        return None
    request_data = {
        "model": modelname,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": PROMPT},
                    {
                        "type": "image_url",
                        # NOTE(review): "data:image/;base64," has an empty MIME
                        # subtype; most OpenAI-compatible servers tolerate it,
                        # but a concrete subtype (e.g. image/jpeg) would be
                        # safer — confirm against the target server.
                        "image_url": {
                            "url": f"data:image/;base64,{img_base64}",
                            # Per the OpenAI vision API, "detail" belongs inside
                            # the image_url object, not as a sibling entry.
                            "detail": "high",
                        },
                    },
                ],
            }
        ],
        "stream": False,
        "max_tokens": 1024,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "top_p": 0.8,
    }
    async with httpx.AsyncClient() as client:
        headers = {}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return await fetch(endpoint, client, request_data, headers=headers)
@router.get("/")
async def read_root():
return {"healthy": True}
@router.post("/")
async def vlm(entity: Entity, request: Request):
global modelname, endpoint, token
metadata_field_name = f"{modelname.replace('-', '_')}_result"
if not entity.file_type_group == "image":
return {metadata_field_name: ""}
# Check if the METADATA_FIELD_NAME field is empty or null
existing_metadata = entity.get_metadata_by_key(metadata_field_name)
if (
existing_metadata
and existing_metadata.value
and existing_metadata.value.strip()
):
logger.info(
f"Skipping processing for file: {entity.filepath} due to existing metadata"
)
# If the field is not empty, return without processing
return {metadata_field_name: existing_metadata.value}
# Check if the entity contains the tag "low_info"
if any(tag.name == "low_info" for tag in entity.tags):
# If the tag is present, return without processing
logger.info(
f"Skipping processing for file: {entity.filepath} due to 'low_info' tag"
)
return {metadata_field_name: ""}
location_url = request.headers.get("Location")
if not location_url:
raise HTTPException(status_code=400, detail="Location header is missing")
patch_url = f"{location_url}/metadata"
vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
print(vlm_result)
if not vlm_result:
print(f"No VLM result found for file: {entity.filepath}")
return {metadata_field_name: "{}"}
async with httpx.AsyncClient() as client:
response = await client.patch(
patch_url,
json={
"metadata_entries": [
{
"key": metadata_field_name,
"value": vlm_result,
"source": PLUGIN_NAME,
"data_type": MetadataType.TEXT_DATA.value,
}
]
},
timeout=30,
)
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code, detail="Failed to patch entity metadata"
)
return {
metadata_field_name: vlm_result,
}
def init_plugin(config):
    """Copy plugin settings from *config* into the module-level globals.

    *config* only needs the attributes modelname, endpoint, token and
    concurrency. Also builds the semaphore that caps concurrent VLM requests.
    """
    global modelname, endpoint, token, concurrency, semaphore
    modelname = config.modelname
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    # One shared semaphore bounds all in-flight requests to the VLM service.
    semaphore = asyncio.Semaphore(concurrency)
if __name__ == "__main__":
    import argparse
    from fastapi import FastAPI

    parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
    # dest="modelname" keeps the CLI flag spelling but stores the value under
    # the attribute name init_plugin() actually reads — argparse would
    # otherwise store it as args.model_name and init_plugin(args) would raise
    # AttributeError on config.modelname.
    parser.add_argument(
        "--model-name", dest="modelname", type=str, default="your_model_name", help="Model name"
    )
    parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL")
    parser.add_argument("--token", type=str, default="your_token", help="Access token")
    parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    args = parser.parse_args()

    init_plugin(args)

    print(f"Model Name: {args.modelname}")
    print(f"Endpoint: {args.endpoint}")
    print(f"Token: {args.token}")
    print(f"Concurrency: {args.concurrency}")
    print(f"Port: {args.port}")

    app = FastAPI()
    app.include_router(router)

    uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -0,0 +1,2 @@
httpx
fastapi

View File

@ -1,7 +1,7 @@
import os
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, status, Query, Request
from fastapi import FastAPI, HTTPException, Depends, status, Query, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
@ -20,6 +20,7 @@ from .read_metadata import read_metadata
import typesense
from .config import get_database_path, settings
from .plugins.vlm import main as vlm_main
from . import crud
from . import indexing
from .schemas import (
@ -39,7 +40,7 @@ from .schemas import (
EntityIndexItem,
MetadataIndexItem,
EntitySearchResult,
SearchResult
SearchResult,
)
# Import the logging configuration
@ -446,7 +447,9 @@ async def search_entities(
library_ids: str = Query(None, description="Comma-separated list of library IDs"),
folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
tags: str = Query(None, description="Comma-separated list of tags"),
created_dates: str = Query(None, description="Comma-separated list of created dates in YYYY-MM-DD format"),
created_dates: str = Query(
None, description="Comma-separated list of created dates in YYYY-MM-DD format"
),
limit: Annotated[int, Query(ge=1, le=200)] = 48,
offset: int = 0,
start: int = None,
@ -456,10 +459,21 @@ async def search_entities(
library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
folder_ids = [int(id) for id in folder_ids.split(",")] if folder_ids else None
tags = [tag.strip() for tag in tags.split(",")] if tags else None
created_dates = [date.strip() for date in created_dates.split(",")] if created_dates else None
created_dates = (
[date.strip() for date in created_dates.split(",")] if created_dates else None
)
try:
return indexing.search_entities(
client, q, library_ids, folder_ids, tags, created_dates, limit, offset, start, end
client,
q,
library_ids,
folder_ids,
tags,
created_dates,
limit,
offset,
start,
end,
)
except Exception as e:
print(f"Error searching entities: {e}")
@ -575,7 +589,7 @@ def is_image(file_path: Path) -> bool:
def get_thumbnail_info(metadata: dict) -> tuple:
if not metadata:
return None, None, None
if not metadata.get("sequence"):
return None, None, False
@ -643,11 +657,20 @@ async def get_file(file_path: str):
raise HTTPException(status_code=404, detail="File not found")
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}")
def run_server():
print("Database path:", get_database_path())
print(
f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}"
)
print(f"VLM plugin enabled: {settings.vlm}")
uvicorn.run(
"memos.server:app",
host="0.0.0.0",