from pydantic import BaseModel, Field


class VLMSettings(BaseModel):
    """Configuration for the VLM (vision-language model) plugin.

    Instances of this model are read attribute-by-attribute by
    ``memos.plugins.vlm.main.init_plugin`` — keep field names in sync
    with that function.

    NOTE(review): the enclosing ``Settings`` class sets ``yaml_file`` in its
    ``SettingsConfigDict``, but pydantic-settings ignores ``yaml_file`` unless
    ``settings_customise_sources`` is overridden to register a
    ``YamlConfigSettingsSource`` — verify the YAML config is actually loaded.
    Also, dropping ``env_prefix="MEMOS_"`` removes environment-variable
    overrides; confirm that is intentional.
    """

    # Whether the plugin router is mounted by the server at startup.
    enabled: bool = False
    # Model name sent in the chat-completions request payload.
    modelname: str = "internvl-1.5"
    # Base URL of the OpenAI-compatible inference endpoint.
    endpoint: str = "http://localhost:11434"
    # Optional bearer token for the endpoint.
    # BUG FIX: repr=False keeps the secret out of logs — server.py prints the
    # whole VLMSettings object at startup, which previously leaked the token.
    token: str = Field(default="", repr=False)
    # Maximum number of in-flight requests to the VLM endpoint.
    concurrency: int = 8
# memos/plugins/vlm/main.py — VLM (vision-language model) plugin.
#
# Receives an Entity, sends its image to an OpenAI-compatible
# chat-completions endpoint, and writes the model's description back to the
# entity's metadata via an HTTP PATCH to the Location URL supplied by the
# caller.
import asyncio
import base64
import logging
import mimetypes
from typing import Optional

import httpx
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request
from PIL import Image

from memos.schemas import Entity, MetadataType

PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"  # "Describe the content of this image"

router = APIRouter(
    tags=[PLUGIN_NAME],
    responses={404: {"description": "Not found"}},
)

# Module-level plugin state; populated by init_plugin() before first use.
modelname = None
endpoint = None
token = None
concurrency = None
semaphore = None

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def image2base64(img_path):
    """Validate the image at *img_path* and return its base64 encoding.

    Returns None (and logs the error) if the file is missing or is not a
    valid image.
    """
    try:
        # Verify first so we never ship corrupt bytes to the model.
        with Image.open(img_path) as img:
            img.verify()
        with open(img_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except Exception as e:
        logger.error("Error processing image %s: %s", img_path, e)
        return None


async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
    """POST a chat-completions request, bounded by the plugin semaphore.

    Returns the first choice's message content, "" if the response had no
    usable choice, or None on a non-200 status.
    """
    # BUG FIX: the original did `async with semaphore` with semaphore
    # possibly still None (init_plugin not called), raising an opaque
    # TypeError; fail with an explicit error instead.
    if semaphore is None:
        raise RuntimeError("VLM plugin not initialised; call init_plugin() first")
    async with semaphore:  # throttle concurrent calls to the endpoint
        response = await client.post(
            f"{endpoint}/v1/chat/completions",
            json=request_data,
            timeout=60,
            headers=headers,
        )
        if response.status_code != 200:
            return None
        result = response.json()
        choices = result.get("choices", [])
        if choices and "message" in choices[0] and "content" in choices[0]["message"]:
            return choices[0]["message"]["content"]
        return ""


async def predict(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    """Ask the VLM to describe the image at *img_path*.

    Returns the model's text, "" for an empty-but-successful response, or
    None if the image is unreadable or the request failed.
    """
    img_base64 = image2base64(img_path)
    if not img_base64:
        return None

    # BUG FIX: the original built "data:image/;base64,..." (empty MIME
    # subtype) and placed "detail" as a sibling of image_url; the OpenAI
    # vision format wants a concrete MIME type and detail *inside* image_url.
    mime = mimetypes.guess_type(img_path)[0] or "image/jpeg"
    request_data = {
        "model": modelname,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": PROMPT},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime};base64,{img_base64}",
                            "detail": "high",
                        },
                    },
                ],
            }
        ],
        "stream": False,
        "max_tokens": 1024,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "top_p": 0.8,
    }

    async with httpx.AsyncClient() as client:
        headers = {}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return await fetch(endpoint, client, request_data, headers=headers)


@router.get("/")
async def read_root():
    """Health-check endpoint."""
    return {"healthy": True}


@router.post("/")
async def vlm(entity: Entity, request: Request):
    """Describe *entity*'s image and persist the result as entity metadata.

    Skips non-images, entities that already carry a result, and entities
    tagged "low_info". Requires a Location header pointing at the entity
    resource; the result is PATCHed to f"{Location}/metadata".
    """
    global modelname, endpoint, token
    metadata_field_name = f"{modelname.replace('-', '_')}_result"
    if entity.file_type_group != "image":
        return {metadata_field_name: ""}

    # If a non-empty result is already stored, return it without reprocessing.
    existing_metadata = entity.get_metadata_by_key(metadata_field_name)
    if (
        existing_metadata
        and existing_metadata.value
        and existing_metadata.value.strip()
    ):
        logger.info(
            "Skipping processing for file: %s due to existing metadata",
            entity.filepath,
        )
        return {metadata_field_name: existing_metadata.value}

    # Entities tagged "low_info" are not worth describing.
    if any(tag.name == "low_info" for tag in entity.tags):
        logger.info(
            "Skipping processing for file: %s due to 'low_info' tag",
            entity.filepath,
        )
        return {metadata_field_name: ""}

    location_url = request.headers.get("Location")
    if not location_url:
        raise HTTPException(status_code=400, detail="Location header is missing")

    patch_url = f"{location_url}/metadata"

    vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)

    logger.info("VLM result for %s: %s", entity.filepath, vlm_result)
    if not vlm_result:
        logger.warning("No VLM result found for file: %s", entity.filepath)
        return {metadata_field_name: "{}"}

    async with httpx.AsyncClient() as client:
        response = await client.patch(
            patch_url,
            json={
                "metadata_entries": [
                    {
                        "key": metadata_field_name,
                        "value": vlm_result,
                        "source": PLUGIN_NAME,
                        "data_type": MetadataType.TEXT_DATA.value,
                    }
                ]
            },
            timeout=30,
        )

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code, detail="Failed to patch entity metadata"
        )

    return {
        metadata_field_name: vlm_result,
    }


def init_plugin(config):
    """Copy connection settings from *config* into module globals.

    Accepts either a VLMSettings object (attribute ``modelname``) or an
    argparse Namespace from the CLI (attribute ``model_name``).
    """
    global modelname, endpoint, token, concurrency, semaphore
    # BUG FIX: the CLI path passes an argparse Namespace whose attribute is
    # model_name (from --model-name), not modelname; the original raised
    # AttributeError when run as a script.
    modelname = getattr(config, "modelname", None) or getattr(
        config, "model_name", None
    )
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    semaphore = asyncio.Semaphore(concurrency)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
    parser.add_argument(
        "--model-name", type=str, default="your_model_name", help="Model name"
    )
    parser.add_argument(
        "--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
    )
    parser.add_argument("--token", type=str, default="your_token", help="Access token")
    parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )

    args = parser.parse_args()

    init_plugin(args)

    print(f"Model Name: {args.model_name}")
    print(f"Endpoint: {args.endpoint}")
    # Don't echo the raw token to the terminal; show only whether one is set.
    print(f"Token set: {bool(args.token and args.token != 'your_token')}")
    print(f"Concurrency: {args.concurrency}")
    print(f"Port: {args.port}")

    app = FastAPI()
    app.include_router(router)

    uvicorn.run(app, host="0.0.0.0", port=args.port)
0000000..d5da81b --- /dev/null +++ b/memos/plugins/vlm/requirements.txt @@ -0,0 +1,2 @@ +httpx +fastapi diff --git a/memos/server.py b/memos/server.py index a74bbd7..f361eeb 100644 --- a/memos/server.py +++ b/memos/server.py @@ -1,7 +1,7 @@ import os import httpx import uvicorn -from fastapi import FastAPI, HTTPException, Depends, status, Query, Request +from fastapi import FastAPI, HTTPException, Depends, status, Query, Request, APIRouter from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse, JSONResponse @@ -20,6 +20,7 @@ from .read_metadata import read_metadata import typesense from .config import get_database_path, settings +from .plugins.vlm import main as vlm_main from . import crud from . import indexing from .schemas import ( @@ -39,7 +40,7 @@ from .schemas import ( EntityIndexItem, MetadataIndexItem, EntitySearchResult, - SearchResult + SearchResult, ) # Import the logging configuration @@ -446,7 +447,9 @@ async def search_entities( library_ids: str = Query(None, description="Comma-separated list of library IDs"), folder_ids: str = Query(None, description="Comma-separated list of folder IDs"), tags: str = Query(None, description="Comma-separated list of tags"), - created_dates: str = Query(None, description="Comma-separated list of created dates in YYYY-MM-DD format"), + created_dates: str = Query( + None, description="Comma-separated list of created dates in YYYY-MM-DD format" + ), limit: Annotated[int, Query(ge=1, le=200)] = 48, offset: int = 0, start: int = None, @@ -456,10 +459,21 @@ async def search_entities( library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None folder_ids = [int(id) for id in folder_ids.split(",")] if folder_ids else None tags = [tag.strip() for tag in tags.split(",")] if tags else None - created_dates = [date.strip() for date in created_dates.split(",")] if created_dates else None + created_dates = ( + [date.strip() for 
date in created_dates.split(",")] if created_dates else None + ) try: return indexing.search_entities( - client, q, library_ids, folder_ids, tags, created_dates, limit, offset, start, end + client, + q, + library_ids, + folder_ids, + tags, + created_dates, + limit, + offset, + start, + end, ) except Exception as e: print(f"Error searching entities: {e}") @@ -575,7 +589,7 @@ def is_image(file_path: Path) -> bool: def get_thumbnail_info(metadata: dict) -> tuple: if not metadata: return None, None, None - + if not metadata.get("sequence"): return None, None, False @@ -643,11 +657,20 @@ async def get_file(file_path: str): raise HTTPException(status_code=404, detail="File not found") +# Add VLM plugin router +if settings.vlm.enabled: + print("VLM plugin is enabled") + vlm_main.init_plugin(settings.vlm) + app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}") + + def run_server(): print("Database path:", get_database_path()) print( f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}" ) + print(f"VLM plugin enabled: {settings.vlm}") + uvicorn.run( "memos.server:app", host="0.0.0.0",