mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(vlm): make prompt a config
This commit is contained in:
parent
b3ebb2c92d
commit
86a73ca72e
@ -10,6 +10,7 @@ from pydantic_settings import (
|
||||
from pydantic import BaseModel, SecretStr
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
import io
|
||||
|
||||
|
||||
class VLMSettings(BaseModel):
|
||||
@ -19,6 +20,7 @@ class VLMSettings(BaseModel):
|
||||
token: str = ""
|
||||
concurrency: int = 1
|
||||
force_jpeg: bool = False
|
||||
prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等"
|
||||
|
||||
|
||||
class OCRSettings(BaseModel):
|
||||
@ -123,14 +125,21 @@ def create_default_config():
|
||||
if not config_path.exists():
|
||||
settings = Settings()
|
||||
os.makedirs(config_path.parent, exist_ok=True)
|
||||
with open(config_path, "w") as f:
|
||||
# Convert settings to a dictionary and ensure order
|
||||
settings_dict = settings.model_dump()
|
||||
ordered_settings = OrderedDict(
|
||||
(key, settings_dict[key]) for key in settings.model_fields.keys()
|
||||
)
|
||||
yaml.dump(ordered_settings, f, Dumper=yaml.Dumper)
|
||||
|
||||
|
||||
# 将设置转换为字典并确保顺序
|
||||
settings_dict = settings.model_dump()
|
||||
ordered_settings = OrderedDict(
|
||||
(key, settings_dict[key]) for key in settings.model_fields.keys()
|
||||
)
|
||||
|
||||
# 使用 io.StringIO 作为中间步骤
|
||||
with io.StringIO() as string_buffer:
|
||||
yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper)
|
||||
yaml_content = string_buffer.getvalue()
|
||||
|
||||
# 将内容写入文件,确保使用 UTF-8 编码
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(yaml_content)
|
||||
|
||||
# Create default config if it doesn't exist
|
||||
create_default_config()
|
||||
|
@ -13,7 +13,6 @@ import numpy as np
|
||||
|
||||
|
||||
PLUGIN_NAME = "vlm"
|
||||
PROMPT = "请帮我尽量详尽的描述这个图片中的内容,包括文字内容、视觉元素等"
|
||||
|
||||
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
|
||||
|
||||
@ -23,6 +22,7 @@ token = None
|
||||
concurrency = None
|
||||
semaphore = None
|
||||
force_jpeg = None
|
||||
prompt = None
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -133,7 +133,7 @@ async def predict_remote(
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
|
||||
},
|
||||
{"type": "text", "text": PROMPT},
|
||||
{"type": "text", "text": prompt}, # Use the global prompt variable here
|
||||
],
|
||||
}
|
||||
],
|
||||
@ -224,13 +224,14 @@ async def vlm(entity: Entity, request: Request):
|
||||
|
||||
|
||||
def init_plugin(config):
|
||||
global modelname, endpoint, token, concurrency, semaphore, force_jpeg
|
||||
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, prompt
|
||||
|
||||
modelname = config.modelname
|
||||
endpoint = config.endpoint
|
||||
token = config.token
|
||||
concurrency = config.concurrency
|
||||
force_jpeg = config.force_jpeg
|
||||
prompt = config.prompt
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
# Print the parameters
|
||||
@ -240,4 +241,5 @@ def init_plugin(config):
|
||||
logger.info(f"Token: {token}")
|
||||
logger.info(f"Concurrency: {concurrency}")
|
||||
logger.info(f"Force JPEG: {force_jpeg}")
|
||||
logger.info(f"Prompt: {prompt}")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user