From 86a73ca72eba97354b932b5058b4537ce3893082 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:07:09 +0800 Subject: [PATCH] feat(vlm): make prompt a config --- memos/config.py | 25 +++++++++++++++++-------- memos/plugins/vlm/main.py | 8 +++++--- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/memos/config.py b/memos/config.py index 938db1d..8b6700c 100644 --- a/memos/config.py +++ b/memos/config.py @@ -10,6 +10,7 @@ from pydantic_settings import ( from pydantic import BaseModel, SecretStr import yaml from collections import OrderedDict +import io class VLMSettings(BaseModel): @@ -19,6 +20,7 @@ class VLMSettings(BaseModel): token: str = "" concurrency: int = 1 force_jpeg: bool = False + prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等" class OCRSettings(BaseModel): @@ -123,14 +125,21 @@ def create_default_config(): if not config_path.exists(): settings = Settings() os.makedirs(config_path.parent, exist_ok=True) - with open(config_path, "w") as f: - # Convert settings to a dictionary and ensure order - settings_dict = settings.model_dump() - ordered_settings = OrderedDict( - (key, settings_dict[key]) for key in settings.model_fields.keys() - ) - yaml.dump(ordered_settings, f, Dumper=yaml.Dumper) - + + # 将设置转换为字典并确保顺序 + settings_dict = settings.model_dump() + ordered_settings = OrderedDict( + (key, settings_dict[key]) for key in settings.model_fields.keys() + ) + + # 使用 io.StringIO 作为中间步骤 + with io.StringIO() as string_buffer: + yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper) + yaml_content = string_buffer.getvalue() + + # 将内容写入文件,确保使用 UTF-8 编码 + with open(config_path, "w", encoding="utf-8") as f: + f.write(yaml_content) # Create default config if it doesn't exist create_default_config() diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index e84d7c6..00d2e7d 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -13,7 +13,6 @@ import numpy as np PLUGIN_NAME = "vlm" -PROMPT = "请帮我尽量详尽的描述这个图片中的内容,包括文字内容、视觉元素等" router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}}) @@ -23,6 +22,7 @@ token = None concurrency = None semaphore = None force_jpeg = None +prompt = None # Configure logger logging.basicConfig(level=logging.INFO) @@ -133,7 +133,7 @@ async def predict_remote( "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{img_base64}"}, }, - {"type": "text", "text": PROMPT}, + {"type": "text", "text": prompt}, # Use the global prompt variable here ], } ], @@ -224,13 +224,14 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): - global modelname, endpoint, token, concurrency, semaphore, force_jpeg + global modelname, endpoint, token, concurrency, semaphore, force_jpeg, prompt modelname = config.modelname endpoint = config.endpoint token = config.token concurrency = config.concurrency force_jpeg = config.force_jpeg + prompt = config.prompt semaphore = asyncio.Semaphore(concurrency) # Print the parameters @@ -240,4 +241,5 @@ def init_plugin(config): logger.info(f"Token: {token}") logger.info(f"Concurrency: {concurrency}") logger.info(f"Force JPEG: {force_jpeg}") + logger.info(f"Prompt: {prompt}")