arkohut 2024-09-09 22:26:48 +08:00
parent 6107c22def
commit 5b1194f1bc

View File

@ -12,6 +12,18 @@ import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports detected in *filename*, without ``flash_attn`` for Florence-2.

    Used as a stand-in for ``transformers.dynamic_module_utils.get_imports``
    (see the ``patch(...)`` call in ``init_plugin``). For every file the
    result of the real ``get_imports`` is returned; only when the file name
    ends with ``modeling_florence2.py`` is ``flash_attn`` filtered out.
    NOTE(review): presumably this lets the model load on machines where
    ``flash_attn`` is not installed (the model is loaded with
    ``attn_implementation="sdpa"``) -- confirm.

    Args:
        filename: Path of the module file whose imports are inspected.

    Returns:
        The list of imported module names reported by ``get_imports``,
        minus ``flash_attn`` for the Florence-2 modeling file.
    """
    imports = get_imports(filename)
    if str(filename).endswith("modeling_florence2.py"):
        # Filter instead of list.remove(): remove() raises ValueError when
        # "flash_attn" is absent (e.g. the upstream modeling file stops
        # importing it), which would break model loading for no reason.
        imports = [name for name in imports if name != "flash_attn"]
    return imports
# Name of this plugin.
PLUGIN_NAME = "vlm"
# Prompt text, written in Chinese; it means "Describe the content of this image".
PROMPT = "描述这张图片的内容"
@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
modelname = config.modelname
endpoint = config.endpoint
token = config.token
@ -279,15 +291,16 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
logger.info("Florence model and processor initialized")
# Print the parameters