mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 19:25:24 +00:00
feat(plugin): add vlm plugin
This commit is contained in:
parent
5f998db989
commit
cead5d9755
@ -1,10 +1,20 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class VLMSettings(BaseModel):
|
||||||
|
enabled: bool = False
|
||||||
|
modelname: str = "internvl-1.5"
|
||||||
|
endpoint: str = "http://localhost:11434"
|
||||||
|
token: str = ""
|
||||||
|
concurrency: int = 8
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = SettingsConfigDict(env_prefix="MEMOS_")
|
model_config = SettingsConfigDict(
|
||||||
|
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
|
||||||
|
yaml_file_encoding="utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
base_dir: str = str(Path.home() / ".memos")
|
base_dir: str = str(Path.home() / ".memos")
|
||||||
database_path: str = os.path.join(base_dir, "database.db")
|
database_path: str = os.path.join(base_dir, "database.db")
|
||||||
@ -15,6 +25,8 @@ class Settings(BaseSettings):
|
|||||||
typesense_connection_timeout_seconds: int = 2
|
typesense_connection_timeout_seconds: int = 2
|
||||||
typesense_collection_name: str = "entities"
|
typesense_collection_name: str = "entities"
|
||||||
|
|
||||||
|
# VLM plugin settings
|
||||||
|
vlm: VLMSettings = VLMSettings()
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
@ -24,7 +36,6 @@ os.makedirs(settings.base_dir, exist_ok=True)
|
|||||||
# Global variable for Typesense collection name
|
# Global variable for Typesense collection name
|
||||||
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
||||||
|
|
||||||
|
|
||||||
# Function to get the database path from environment variable or default
|
# Function to get the database path from environment variable or default
|
||||||
def get_database_path():
|
def get_database_path():
|
||||||
return settings.database_path
|
return settings.database_path
|
0
memos/plugins/__init__.py
Normal file
0
memos/plugins/__init__.py
Normal file
0
memos/plugins/vlm/__init__.py
Normal file
0
memos/plugins/vlm/__init__.py
Normal file
202
memos/plugins/vlm/main.py
Normal file
202
memos/plugins/vlm/main.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import base64
|
||||||
|
import httpx
|
||||||
|
from PIL import Image
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
||||||
|
from memos.schemas import Entity, MetadataType
|
||||||
|
import logging
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
PLUGIN_NAME = "vlm"
|
||||||
|
PROMPT = "描述这张图片的内容"
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
tags=[PLUGIN_NAME],
|
||||||
|
responses={404: {"description": "Not found"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
modelname = None
|
||||||
|
endpoint = None
|
||||||
|
token = None
|
||||||
|
concurrency = None
|
||||||
|
semaphore = None
|
||||||
|
|
||||||
|
# Configure logger
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def image2base64(img_path):
|
||||||
|
try:
|
||||||
|
# Attempt to open and verify the image
|
||||||
|
with Image.open(img_path) as img:
|
||||||
|
img.verify() # Verify that it's a valid image file
|
||||||
|
|
||||||
|
# If verification passes, encode the image
|
||||||
|
with open(img_path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
return encoded_string
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing image {img_path}: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
|
||||||
|
async with semaphore: # 使用信号量控制并发
|
||||||
|
response = await client.post(
|
||||||
|
f"{endpoint}/v1/chat/completions",
|
||||||
|
json=request_data,
|
||||||
|
timeout=60,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
return None
|
||||||
|
result = response.json()
|
||||||
|
choices = result.get("choices", [])
|
||||||
|
if choices and "message" in choices[0] and "content" in choices[0]["message"]:
|
||||||
|
return choices[0]["message"]["content"]
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
async def predict(
|
||||||
|
endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
|
||||||
|
) -> Optional[str]:
|
||||||
|
img_base64 = image2base64(img_path)
|
||||||
|
if not img_base64:
|
||||||
|
return None
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"model": modelname,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": PROMPT},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/;base64,{img_base64}"},
|
||||||
|
"detail": "high",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": False,
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"temperature": 0.1,
|
||||||
|
"repetition_penalty": 1.1,
|
||||||
|
"top_p": 0.8,
|
||||||
|
}
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
headers = {}
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
|
return await fetch(endpoint, client, request_data, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
async def read_root():
|
||||||
|
return {"healthy": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/")
|
||||||
|
async def vlm(entity: Entity, request: Request):
|
||||||
|
global modelname, endpoint, token
|
||||||
|
metadata_field_name = f"{modelname.replace('-', '_')}_result"
|
||||||
|
if not entity.file_type_group == "image":
|
||||||
|
return {metadata_field_name: ""}
|
||||||
|
|
||||||
|
# Check if the METADATA_FIELD_NAME field is empty or null
|
||||||
|
existing_metadata = entity.get_metadata_by_key(metadata_field_name)
|
||||||
|
if (
|
||||||
|
existing_metadata
|
||||||
|
and existing_metadata.value
|
||||||
|
and existing_metadata.value.strip()
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Skipping processing for file: {entity.filepath} due to existing metadata"
|
||||||
|
)
|
||||||
|
# If the field is not empty, return without processing
|
||||||
|
return {metadata_field_name: existing_metadata.value}
|
||||||
|
|
||||||
|
# Check if the entity contains the tag "low_info"
|
||||||
|
if any(tag.name == "low_info" for tag in entity.tags):
|
||||||
|
# If the tag is present, return without processing
|
||||||
|
logger.info(
|
||||||
|
f"Skipping processing for file: {entity.filepath} due to 'low_info' tag"
|
||||||
|
)
|
||||||
|
return {metadata_field_name: ""}
|
||||||
|
|
||||||
|
location_url = request.headers.get("Location")
|
||||||
|
if not location_url:
|
||||||
|
raise HTTPException(status_code=400, detail="Location header is missing")
|
||||||
|
|
||||||
|
patch_url = f"{location_url}/metadata"
|
||||||
|
|
||||||
|
vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
|
||||||
|
|
||||||
|
print(vlm_result)
|
||||||
|
if not vlm_result:
|
||||||
|
print(f"No VLM result found for file: {entity.filepath}")
|
||||||
|
return {metadata_field_name: "{}"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.patch(
|
||||||
|
patch_url,
|
||||||
|
json={
|
||||||
|
"metadata_entries": [
|
||||||
|
{
|
||||||
|
"key": metadata_field_name,
|
||||||
|
"value": vlm_result,
|
||||||
|
"source": PLUGIN_NAME,
|
||||||
|
"data_type": MetadataType.TEXT_DATA.value,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status_code, detail="Failed to patch entity metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
metadata_field_name: vlm_result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def init_plugin(config):
|
||||||
|
global modelname, endpoint, token, concurrency, semaphore
|
||||||
|
modelname = config.modelname
|
||||||
|
endpoint = config.endpoint
|
||||||
|
token = config.token
|
||||||
|
concurrency = config.concurrency
|
||||||
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
|
||||||
|
parser.add_argument("--model-name", type=str, default="your_model_name", help="Model name")
|
||||||
|
parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL")
|
||||||
|
parser.add_argument("--token", type=str, default="your_token", help="Access token")
|
||||||
|
parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
|
||||||
|
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
init_plugin(args)
|
||||||
|
|
||||||
|
print(f"Model Name: {args.model_name}")
|
||||||
|
print(f"Endpoint: {args.endpoint}")
|
||||||
|
print(f"Token: {args.token}")
|
||||||
|
print(f"Concurrency: {args.concurrency}")
|
||||||
|
print(f"Port: {args.port}")
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
2
memos/plugins/vlm/requirements.txt
Normal file
2
memos/plugins/vlm/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
httpx
|
||||||
|
fastapi
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import httpx
|
import httpx
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException, Depends, status, Query, Request
|
from fastapi import FastAPI, HTTPException, Depends, status, Query, Request, APIRouter
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
@ -20,6 +20,7 @@ from .read_metadata import read_metadata
|
|||||||
import typesense
|
import typesense
|
||||||
|
|
||||||
from .config import get_database_path, settings
|
from .config import get_database_path, settings
|
||||||
|
from .plugins.vlm import main as vlm_main
|
||||||
from . import crud
|
from . import crud
|
||||||
from . import indexing
|
from . import indexing
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
@ -39,7 +40,7 @@ from .schemas import (
|
|||||||
EntityIndexItem,
|
EntityIndexItem,
|
||||||
MetadataIndexItem,
|
MetadataIndexItem,
|
||||||
EntitySearchResult,
|
EntitySearchResult,
|
||||||
SearchResult
|
SearchResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import the logging configuration
|
# Import the logging configuration
|
||||||
@ -446,7 +447,9 @@ async def search_entities(
|
|||||||
library_ids: str = Query(None, description="Comma-separated list of library IDs"),
|
library_ids: str = Query(None, description="Comma-separated list of library IDs"),
|
||||||
folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
|
folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
|
||||||
tags: str = Query(None, description="Comma-separated list of tags"),
|
tags: str = Query(None, description="Comma-separated list of tags"),
|
||||||
created_dates: str = Query(None, description="Comma-separated list of created dates in YYYY-MM-DD format"),
|
created_dates: str = Query(
|
||||||
|
None, description="Comma-separated list of created dates in YYYY-MM-DD format"
|
||||||
|
),
|
||||||
limit: Annotated[int, Query(ge=1, le=200)] = 48,
|
limit: Annotated[int, Query(ge=1, le=200)] = 48,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
start: int = None,
|
start: int = None,
|
||||||
@ -456,10 +459,21 @@ async def search_entities(
|
|||||||
library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
|
library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
|
||||||
folder_ids = [int(id) for id in folder_ids.split(",")] if folder_ids else None
|
folder_ids = [int(id) for id in folder_ids.split(",")] if folder_ids else None
|
||||||
tags = [tag.strip() for tag in tags.split(",")] if tags else None
|
tags = [tag.strip() for tag in tags.split(",")] if tags else None
|
||||||
created_dates = [date.strip() for date in created_dates.split(",")] if created_dates else None
|
created_dates = (
|
||||||
|
[date.strip() for date in created_dates.split(",")] if created_dates else None
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return indexing.search_entities(
|
return indexing.search_entities(
|
||||||
client, q, library_ids, folder_ids, tags, created_dates, limit, offset, start, end
|
client,
|
||||||
|
q,
|
||||||
|
library_ids,
|
||||||
|
folder_ids,
|
||||||
|
tags,
|
||||||
|
created_dates,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error searching entities: {e}")
|
print(f"Error searching entities: {e}")
|
||||||
@ -575,7 +589,7 @@ def is_image(file_path: Path) -> bool:
|
|||||||
def get_thumbnail_info(metadata: dict) -> tuple:
|
def get_thumbnail_info(metadata: dict) -> tuple:
|
||||||
if not metadata:
|
if not metadata:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
if not metadata.get("sequence"):
|
if not metadata.get("sequence"):
|
||||||
return None, None, False
|
return None, None, False
|
||||||
|
|
||||||
@ -643,11 +657,20 @@ async def get_file(file_path: str):
|
|||||||
raise HTTPException(status_code=404, detail="File not found")
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
|
||||||
|
# Add VLM plugin router
|
||||||
|
if settings.vlm.enabled:
|
||||||
|
print("VLM plugin is enabled")
|
||||||
|
vlm_main.init_plugin(settings.vlm)
|
||||||
|
app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}")
|
||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
print("Database path:", get_database_path())
|
print("Database path:", get_database_path())
|
||||||
print(
|
print(
|
||||||
f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}"
|
f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}"
|
||||||
)
|
)
|
||||||
|
print(f"VLM plugin enabled: {settings.vlm}")
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"memos.server:app",
|
"memos.server:app",
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user