Mirror of https://github.com/tcsenpai/pensieve.git (synced 2025-06-07 19:55:25 +00:00)
Revert "fix: skip flash attn https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d"
This reverts commit 5b1194f1bc6ad055fe69a9a2207ef1c31cc73721.
parent d08c9e0e36
commit 3912d165f6
@@ -12,18 +12,6 @@ import io
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
-
-from unittest.mock import patch
-from transformers.dynamic_module_utils import get_imports
-
-
-def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
-    if not str(filename).endswith("modeling_florence2.py"):
-        return get_imports(filename)
-    imports = get_imports(filename)
-    imports.remove("flash_attn")
-    return imports
-
 
 PLUGIN_NAME = "vlm"
 PROMPT = "描述这张图片的内容"
 
@@ -291,7 +279,6 @@ def init_plugin(config):
     )
     logger.info(f"Using device: {device}")
 
-    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
-        florence_model = AutoModelForCausalLM.from_pretrained(
-            "microsoft/Florence-2-base-ft",
-            torch_dtype=torch_dtype,
+    florence_model = AutoModelForCausalLM.from_pretrained(
+        "microsoft/Florence-2-base-ft",
+        torch_dtype=torch_dtype,
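For context, the workaround being reverted is the one suggested in the Hugging Face discussion linked in the commit title: Florence-2's remote modeling code declares flash_attn as a hard import, so on machines without flash_attn the model can still be loaded by monkeypatching transformers' get_imports to drop that entry. Below is a minimal, self-contained sketch of that pattern, assembled from the removed hunks; the device/dtype selection, trust_remote_code=True, the membership guard before remove(), and the processor load are assumptions added for completeness, not part of this diff. (The PROMPT string above is Chinese for "Describe the content of this image.")

import os
from unittest.mock import patch

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports


def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    # Only touch Florence-2's modeling file; leave every other module alone.
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:  # guard (assumption): newer revisions may not declare it
        imports.remove("flash_attn")
    return imports


device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption
torch_dtype = torch.float16 if device == "cuda" else torch.float32  # assumption

# While the patch is active, transformers skips its flash_attn availability
# check when importing the remote Florence-2 modeling code.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    florence_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base-ft",
        torch_dtype=torch_dtype,
        trust_remote_code=True,  # Florence-2 ships custom modeling code
    ).to(device)
    florence_processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base-ft", trust_remote_code=True
    )

After this revert the plugin presumably loads the model directly again, which works only where flash_attn is installed or where the installed transformers/Florence-2 revision no longer requires it.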