arkohut 2024-09-10 00:31:56 +08:00
parent 4b189c22d2
commit 41c7e136d9

View File

@ -12,6 +12,18 @@ import io
import torch import torch
from transformers import AutoModelForCausalLM, AutoProcessor from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
PLUGIN_NAME = "vlm" PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容" PROMPT = "描述这张图片的内容"
@ -279,6 +291,7 @@ def init_plugin(config):
) )
logger.info(f"Using device: {device}") logger.info(f"Using device: {device}")
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained( florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft", "microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,