feat(vlm): make prompt a config

arkohut 2024-10-08 19:07:09 +08:00
parent b3ebb2c92d
commit 86a73ca72e
2 changed files with 22 additions and 11 deletions

View File

@@ -10,6 +10,7 @@ from pydantic_settings import (
from pydantic import BaseModel, SecretStr
import yaml
from collections import OrderedDict
import io
class VLMSettings(BaseModel):
@@ -19,6 +20,7 @@ class VLMSettings(BaseModel):
token: str = ""
concurrency: int = 1
force_jpeg: bool = False
prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等"
class OCRSettings(BaseModel):
@@ -123,14 +125,21 @@ def create_default_config():
if not config_path.exists():
settings = Settings()
os.makedirs(config_path.parent, exist_ok=True)
with open(config_path, "w") as f:
# Convert settings to a dictionary and ensure order
settings_dict = settings.model_dump()
ordered_settings = OrderedDict(
(key, settings_dict[key]) for key in settings.model_fields.keys()
)
yaml.dump(ordered_settings, f, Dumper=yaml.Dumper)
# Convert settings to a dictionary and ensure order
settings_dict = settings.model_dump()
ordered_settings = OrderedDict(
(key, settings_dict[key]) for key in settings.model_fields.keys()
)
# Use io.StringIO as an intermediate step
with io.StringIO() as string_buffer:
yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper)
yaml_content = string_buffer.getvalue()
# Write the content to the file, ensuring UTF-8 encoding
with open(config_path, "w", encoding="utf-8") as f:
f.write(yaml_content)
# Create default config if it doesn't exist
create_default_config()
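
Why the io.StringIO + allow_unicode change matters (a minimal sketch, not part of the diff): PyYAML escapes non-ASCII characters unless allow_unicode=True is passed, so the new Chinese default prompt would otherwise land in config.yaml as \uXXXX escape sequences.

import yaml

data = {"prompt": "请帮描述这个图片中的内容"}
print(yaml.dump(data))                      # non-ASCII text is emitted as \uXXXX escapes
print(yaml.dump(data, allow_unicode=True))  # non-ASCII text stays readable

# Writing the dumped string through open(..., encoding="utf-8") then keeps the
# file valid UTF-8 regardless of the platform's default encoding.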

View File

@@ -13,7 +13,6 @@ import numpy as np
PLUGIN_NAME = "vlm"
PROMPT = "请帮我尽量详尽的描述这个图片中的内容,包括文字内容、视觉元素等"
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
@@ -23,6 +22,7 @@ token = None
concurrency = None
semaphore = None
force_jpeg = None
prompt = None
# Configure logger
logging.basicConfig(level=logging.INFO)
@@ -133,7 +133,7 @@ async def predict_remote(
"type": "image_url",
"image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
},
{"type": "text", "text": PROMPT},
{"type": "text", "text": prompt}, # Use the global prompt variable here
],
}
],
@@ -224,13 +224,14 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, prompt
modelname = config.modelname
endpoint = config.endpoint
token = config.token
concurrency = config.concurrency
force_jpeg = config.force_jpeg
prompt = config.prompt
semaphore = asyncio.Semaphore(concurrency)
# Print the parameters
@@ -240,4 +241,5 @@ def init_plugin(config):
logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}")
logger.info(f"Prompt: {prompt}")