mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-10 04:57:12 +00:00
feat(vlm): make prompt a config
This commit is contained in:
parent
b3ebb2c92d
commit
86a73ca72e
@ -10,6 +10,7 @@ from pydantic_settings import (
|
|||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
import yaml
|
import yaml
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
class VLMSettings(BaseModel):
|
class VLMSettings(BaseModel):
|
||||||
@ -19,6 +20,7 @@ class VLMSettings(BaseModel):
|
|||||||
token: str = ""
|
token: str = ""
|
||||||
concurrency: int = 1
|
concurrency: int = 1
|
||||||
force_jpeg: bool = False
|
force_jpeg: bool = False
|
||||||
|
prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等"
|
||||||
|
|
||||||
|
|
||||||
class OCRSettings(BaseModel):
|
class OCRSettings(BaseModel):
|
||||||
@ -123,14 +125,21 @@ def create_default_config():
|
|||||||
if not config_path.exists():
|
if not config_path.exists():
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
os.makedirs(config_path.parent, exist_ok=True)
|
os.makedirs(config_path.parent, exist_ok=True)
|
||||||
with open(config_path, "w") as f:
|
|
||||||
# Convert settings to a dictionary and ensure order
|
# 将设置转换为字典并确保顺序
|
||||||
settings_dict = settings.model_dump()
|
settings_dict = settings.model_dump()
|
||||||
ordered_settings = OrderedDict(
|
ordered_settings = OrderedDict(
|
||||||
(key, settings_dict[key]) for key in settings.model_fields.keys()
|
(key, settings_dict[key]) for key in settings.model_fields.keys()
|
||||||
)
|
)
|
||||||
yaml.dump(ordered_settings, f, Dumper=yaml.Dumper)
|
|
||||||
|
|
||||||
|
# 使用 io.StringIO 作为中间步骤
|
||||||
|
with io.StringIO() as string_buffer:
|
||||||
|
yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper)
|
||||||
|
yaml_content = string_buffer.getvalue()
|
||||||
|
|
||||||
|
# 将内容写入文件,确保使用 UTF-8 编码
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(yaml_content)
|
||||||
|
|
||||||
# Create default config if it doesn't exist
|
# Create default config if it doesn't exist
|
||||||
create_default_config()
|
create_default_config()
|
||||||
|
@ -13,7 +13,6 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
PLUGIN_NAME = "vlm"
|
PLUGIN_NAME = "vlm"
|
||||||
PROMPT = "请帮我尽量详尽的描述这个图片中的内容,包括文字内容、视觉元素等"
|
|
||||||
|
|
||||||
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
|
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
|
||||||
|
|
||||||
@ -23,6 +22,7 @@ token = None
|
|||||||
concurrency = None
|
concurrency = None
|
||||||
semaphore = None
|
semaphore = None
|
||||||
force_jpeg = None
|
force_jpeg = None
|
||||||
|
prompt = None
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -133,7 +133,7 @@ async def predict_remote(
|
|||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
|
"image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
|
||||||
},
|
},
|
||||||
{"type": "text", "text": PROMPT},
|
{"type": "text", "text": prompt}, # Use the global prompt variable here
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -224,13 +224,14 @@ async def vlm(entity: Entity, request: Request):
|
|||||||
|
|
||||||
|
|
||||||
def init_plugin(config):
|
def init_plugin(config):
|
||||||
global modelname, endpoint, token, concurrency, semaphore, force_jpeg
|
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, prompt
|
||||||
|
|
||||||
modelname = config.modelname
|
modelname = config.modelname
|
||||||
endpoint = config.endpoint
|
endpoint = config.endpoint
|
||||||
token = config.token
|
token = config.token
|
||||||
concurrency = config.concurrency
|
concurrency = config.concurrency
|
||||||
force_jpeg = config.force_jpeg
|
force_jpeg = config.force_jpeg
|
||||||
|
prompt = config.prompt
|
||||||
semaphore = asyncio.Semaphore(concurrency)
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
# Print the parameters
|
# Print the parameters
|
||||||
@ -240,4 +241,5 @@ def init_plugin(config):
|
|||||||
logger.info(f"Token: {token}")
|
logger.info(f"Token: {token}")
|
||||||
logger.info(f"Concurrency: {concurrency}")
|
logger.info(f"Concurrency: {concurrency}")
|
||||||
logger.info(f"Force JPEG: {force_jpeg}")
|
logger.info(f"Force JPEG: {force_jpeg}")
|
||||||
|
logger.info(f"Prompt: {prompt}")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user