mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
fix: skip flash attn
https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d
This commit is contained in:
parent
4b189c22d2
commit
41c7e136d9
@ -12,6 +12,18 @@ import io
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
|
||||
from unittest.mock import patch
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
|
||||
|
||||
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
|
||||
if not str(filename).endswith("modeling_florence2.py"):
|
||||
return get_imports(filename)
|
||||
imports = get_imports(filename)
|
||||
imports.remove("flash_attn")
|
||||
return imports
|
||||
|
||||
|
||||
PLUGIN_NAME = "vlm"
|
||||
PROMPT = "描述这张图片的内容"
|
||||
|
||||
@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request):
|
||||
|
||||
def init_plugin(config):
|
||||
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
|
||||
|
||||
|
||||
modelname = config.modelname
|
||||
endpoint = config.endpoint
|
||||
token = config.token
|
||||
@ -279,15 +291,16 @@ def init_plugin(config):
|
||||
)
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
florence_model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Florence-2-base-ft",
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="sdpa",
|
||||
trust_remote_code=True,
|
||||
).to(device)
|
||||
florence_processor = AutoProcessor.from_pretrained(
|
||||
"microsoft/Florence-2-base-ft", trust_remote_code=True
|
||||
)
|
||||
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
||||
florence_model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Florence-2-base-ft",
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="sdpa",
|
||||
trust_remote_code=True,
|
||||
).to(device)
|
||||
florence_processor = AutoProcessor.from_pretrained(
|
||||
"microsoft/Florence-2-base-ft", trust_remote_code=True
|
||||
)
|
||||
logger.info("Florence model and processor initialized")
|
||||
|
||||
# Print the parameters
|
||||
|
Loading…
x
Reference in New Issue
Block a user