mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(plugin): add vlm plugin
This commit is contained in:
parent
5f998db989
commit
cead5d9755
@ -1,10 +1,20 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
class VLMSettings(BaseModel):
    """Configuration for the VLM (vision-language model) plugin."""

    # Plugin is only initialized and mounted on the server when True.
    enabled: bool = False
    modelname: str = "internvl-1.5"
    # OpenAI-compatible chat-completions server; 11434 is Ollama's default
    # port — presumably an Ollama instance, confirm against deployment.
    endpoint: str = "http://localhost:11434"
    # Optional bearer token sent as "Authorization: Bearer <token>".
    token: str = ""
    # Maximum number of simultaneous requests to the VLM endpoint.
    concurrency: int = 8
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="MEMOS_")
|
||||
model_config = SettingsConfigDict(
|
||||
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
|
||||
yaml_file_encoding="utf-8"
|
||||
)
|
||||
|
||||
base_dir: str = str(Path.home() / ".memos")
|
||||
database_path: str = os.path.join(base_dir, "database.db")
|
||||
@ -15,6 +25,8 @@ class Settings(BaseSettings):
|
||||
typesense_connection_timeout_seconds: int = 2
|
||||
typesense_collection_name: str = "entities"
|
||||
|
||||
# VLM plugin settings
|
||||
vlm: VLMSettings = VLMSettings()
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@ -24,7 +36,6 @@ os.makedirs(settings.base_dir, exist_ok=True)
|
||||
# Global variable for Typesense collection name
|
||||
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
||||
|
||||
|
||||
# Function to get the database path from environment variable or default
def get_database_path() -> str:
    """Return the database file path resolved by the loaded Settings."""
    return settings.database_path
|
0
memos/plugins/__init__.py
Normal file
0
memos/plugins/__init__.py
Normal file
0
memos/plugins/vlm/__init__.py
Normal file
0
memos/plugins/vlm/__init__.py
Normal file
202
memos/plugins/vlm/main.py
Normal file
202
memos/plugins/vlm/main.py
Normal file
@ -0,0 +1,202 @@
|
||||
import asyncio
import base64
import logging
import mimetypes
from typing import Optional

import httpx
import uvicorn
from fastapi import APIRouter, FastAPI, Request, HTTPException
from PIL import Image

from memos.schemas import Entity, MetadataType
|
||||
|
||||
PLUGIN_NAME = "vlm"
# Prompt sent with every image; Chinese for "describe the contents of this image".
PROMPT = "描述这张图片的内容"

router = APIRouter(
    tags=[PLUGIN_NAME],
    responses={404: {"description": "Not found"}}
)

# Plugin configuration, populated by init_plugin() before the router serves
# requests; the handlers below read these module globals directly.
modelname = None
endpoint = None
token = None
concurrency = None
# Bounds the number of in-flight VLM requests (see fetch()).
semaphore = None

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def image2base64(img_path):
    """Return the base64-encoded contents of a verified image file.

    Logs the error and returns None when the file cannot be opened or is
    not a valid image.
    """
    try:
        # Verify first so we never encode a corrupt or non-image file.
        with Image.open(img_path) as img:
            img.verify()

        # verify() leaves the Image object unusable, so read the raw
        # bytes straight from disk for encoding.
        with open(img_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None
|
||||
|
||||
|
||||
async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
    """POST *request_data* to the chat-completions API and return the reply.

    Returns the assistant message content, "" when the response carries no
    content, or None on a non-200 status.  The module-level semaphore bounds
    how many requests are in flight at once.
    """
    async with semaphore:  # limit concurrent requests via the semaphore
        response = await client.post(
            f"{endpoint}/v1/chat/completions",
            json=request_data,
            timeout=60,
            headers=headers,
        )
        if response.status_code != 200:
            return None
        choices = response.json().get("choices", [])
        if not choices:
            return ""
        message = choices[0].get("message")
        if message and "content" in message:
            return message["content"]
        return ""
|
||||
|
||||
|
||||
async def predict(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    """Ask the VLM at *endpoint* to describe the image at *img_path*.

    Returns the model's text reply, "" when the reply has no content, or
    None when the image is unreadable or the request fails.
    """
    img_base64 = image2base64(img_path)
    if not img_base64:
        return None

    # Derive the real MIME type for the data URL.  The previous literal
    # "data:image/;base64,..." had no subtype, which is malformed and
    # rejected by some OpenAI-compatible servers.
    mime_type, _ = mimetypes.guess_type(img_path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    request_data = {
        "model": modelname,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": PROMPT},
                    {
                        "type": "image_url",
                        # Per the OpenAI vision API, "detail" belongs
                        # inside the image_url object, not beside it.
                        "image_url": {
                            "url": f"data:{mime_type};base64,{img_base64}",
                            "detail": "high",
                        },
                    },
                ],
            }
        ],
        "stream": False,
        "max_tokens": 1024,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "top_p": 0.8,
    }
    async with httpx.AsyncClient() as client:
        headers = {}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return await fetch(endpoint, client, request_data, headers=headers)
|
||||
|
||||
|
||||
@router.get("/")
async def read_root():
    """Health-check endpoint for the VLM plugin."""
    status = {"healthy": True}
    return status
|
||||
|
||||
|
||||
@router.post("/")
async def vlm(entity: Entity, request: Request):
    """Generate a VLM description for an image entity and persist it.

    Skips non-image entities, entities that already carry a non-empty
    result for the configured model, and entities tagged "low_info".
    On success the result is PATCHed to the entity's metadata endpoint
    (base URL taken from the request's Location header) and returned.

    Raises HTTPException 400 when the Location header is missing, or the
    upstream status when the metadata PATCH fails.
    """
    global modelname, endpoint, token
    # Metadata key derives from the model name, e.g. "internvl_1.5_result".
    metadata_field_name = f"{modelname.replace('-', '_')}_result"
    if entity.file_type_group != "image":
        return {metadata_field_name: ""}

    # Check if the METADATA_FIELD_NAME field is empty or null
    existing_metadata = entity.get_metadata_by_key(metadata_field_name)
    if (
        existing_metadata
        and existing_metadata.value
        and existing_metadata.value.strip()
    ):
        logger.info(
            f"Skipping processing for file: {entity.filepath} due to existing metadata"
        )
        # If the field is not empty, return without processing
        return {metadata_field_name: existing_metadata.value}

    # Check if the entity contains the tag "low_info"
    if any(tag.name == "low_info" for tag in entity.tags):
        # If the tag is present, return without processing
        logger.info(
            f"Skipping processing for file: {entity.filepath} due to 'low_info' tag"
        )
        return {metadata_field_name: ""}

    location_url = request.headers.get("Location")
    if not location_url:
        raise HTTPException(status_code=400, detail="Location header is missing")

    patch_url = f"{location_url}/metadata"

    vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)

    # Use the module logger rather than print() so output respects the
    # configured logging setup.
    logger.debug(vlm_result)
    if not vlm_result:
        logger.info(f"No VLM result found for file: {entity.filepath}")
        # "{}" marks the entity as processed-but-empty.
        return {metadata_field_name: "{}"}

    async with httpx.AsyncClient() as client:
        response = await client.patch(
            patch_url,
            json={
                "metadata_entries": [
                    {
                        "key": metadata_field_name,
                        "value": vlm_result,
                        "source": PLUGIN_NAME,
                        "data_type": MetadataType.TEXT_DATA.value,
                    }
                ]
            },
            timeout=30,
        )

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code, detail="Failed to patch entity metadata"
        )

    return {
        metadata_field_name: vlm_result,
    }
|
||||
|
||||
|
||||
def init_plugin(config):
    """Copy plugin settings from *config* into the module globals.

    Must run before the router handles any request: the handlers and
    fetch() read these globals, and the semaphore bounds concurrency.
    """
    global modelname, endpoint, token, concurrency, semaphore
    modelname, endpoint, token, concurrency = (
        config.modelname,
        config.endpoint,
        config.token,
        config.concurrency,
    )
    semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    from fastapi import FastAPI

    parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
    # dest="modelname" makes the parsed namespace match the attribute that
    # init_plugin() reads (config.modelname); without it argparse exposes
    # --model-name as args.model_name and init_plugin crashes with
    # AttributeError.
    parser.add_argument(
        "--model-name", dest="modelname", type=str, default="your_model_name", help="Model name"
    )
    parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL")
    parser.add_argument("--token", type=str, default="your_token", help="Access token")
    parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")

    args = parser.parse_args()

    init_plugin(args)

    print(f"Model Name: {args.modelname}")
    print(f"Endpoint: {args.endpoint}")
    print(f"Token: {args.token}")
    print(f"Concurrency: {args.concurrency}")
    print(f"Port: {args.port}")

    # Mount the plugin router on a standalone app for local testing.
    app = FastAPI()
    app.include_router(router)

    uvicorn.run(app, host="0.0.0.0", port=args.port)
|
2
memos/plugins/vlm/requirements.txt
Normal file
2
memos/plugins/vlm/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
||||
httpx
|
||||
fastapi
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Depends, status, Query, Request
|
||||
from fastapi import FastAPI, HTTPException, Depends, status, Query, Request, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
@ -20,6 +20,7 @@ from .read_metadata import read_metadata
|
||||
import typesense
|
||||
|
||||
from .config import get_database_path, settings
|
||||
from .plugins.vlm import main as vlm_main
|
||||
from . import crud
|
||||
from . import indexing
|
||||
from .schemas import (
|
||||
@ -39,7 +40,7 @@ from .schemas import (
|
||||
EntityIndexItem,
|
||||
MetadataIndexItem,
|
||||
EntitySearchResult,
|
||||
SearchResult
|
||||
SearchResult,
|
||||
)
|
||||
|
||||
# Import the logging configuration
|
||||
@ -446,7 +447,9 @@ async def search_entities(
|
||||
library_ids: str = Query(None, description="Comma-separated list of library IDs"),
|
||||
folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
|
||||
tags: str = Query(None, description="Comma-separated list of tags"),
|
||||
created_dates: str = Query(None, description="Comma-separated list of created dates in YYYY-MM-DD format"),
|
||||
created_dates: str = Query(
|
||||
None, description="Comma-separated list of created dates in YYYY-MM-DD format"
|
||||
),
|
||||
limit: Annotated[int, Query(ge=1, le=200)] = 48,
|
||||
offset: int = 0,
|
||||
start: int = None,
|
||||
@ -456,10 +459,21 @@ async def search_entities(
|
||||
library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
|
||||
folder_ids = [int(id) for id in folder_ids.split(",")] if folder_ids else None
|
||||
tags = [tag.strip() for tag in tags.split(",")] if tags else None
|
||||
created_dates = [date.strip() for date in created_dates.split(",")] if created_dates else None
|
||||
created_dates = (
|
||||
[date.strip() for date in created_dates.split(",")] if created_dates else None
|
||||
)
|
||||
try:
|
||||
return indexing.search_entities(
|
||||
client, q, library_ids, folder_ids, tags, created_dates, limit, offset, start, end
|
||||
client,
|
||||
q,
|
||||
library_ids,
|
||||
folder_ids,
|
||||
tags,
|
||||
created_dates,
|
||||
limit,
|
||||
offset,
|
||||
start,
|
||||
end,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error searching entities: {e}")
|
||||
@ -575,7 +589,7 @@ def is_image(file_path: Path) -> bool:
|
||||
def get_thumbnail_info(metadata: dict) -> tuple:
|
||||
if not metadata:
|
||||
return None, None, None
|
||||
|
||||
|
||||
if not metadata.get("sequence"):
|
||||
return None, None, False
|
||||
|
||||
@ -643,11 +657,20 @@ async def get_file(file_path: str):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
|
||||
# Add VLM plugin router
if settings.vlm.enabled:
    print("VLM plugin is enabled")
    # Push the VLM settings into the plugin's module globals before mounting,
    # since its request handlers read them at call time.
    vlm_main.init_plugin(settings.vlm)
    app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}")
|
||||
|
||||
|
||||
def run_server():
|
||||
print("Database path:", get_database_path())
|
||||
print(
|
||||
f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}"
|
||||
)
|
||||
print(f"VLM plugin enabled: {settings.vlm}")
|
||||
|
||||
uvicorn.run(
|
||||
"memos.server:app",
|
||||
host="0.0.0.0",
|
||||
|
Loading…
x
Reference in New Issue
Block a user