arkohut 2024-09-09 22:26:48 +08:00
parent 6107c22def
commit 5b1194f1bc

View File

@ -12,6 +12,18 @@ import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@ -279,6 +291,7 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,