This reverts commit 5b1194f1bc6ad055fe69a9a2207ef1c31cc73721.
This commit is contained in:
arkohut 2024-09-09 23:27:40 +08:00
parent d08c9e0e36
commit 3912d165f6

View File

@ -12,18 +12,6 @@ import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@ -263,7 +251,7 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
modelname = config.modelname
endpoint = config.endpoint
token = config.token
@ -291,16 +279,15 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
logger.info("Florence model and processor initialized")
# Print the parameters