Mirror of https://github.com/tcsenpai/pensieve.git (synced 2025-06-06 19:25:24 +00:00)
fix: skip flash attn
https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d
This commit is contained in:
parent 4b189c22d2
commit 41c7e136d9
@@ -12,6 +12,18 @@ import io
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
 
+from unittest.mock import patch
+from transformers.dynamic_module_utils import get_imports
+
+
+def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
+    if not str(filename).endswith("modeling_florence2.py"):
+        return get_imports(filename)
+    imports = get_imports(filename)
+    imports.remove("flash_attn")
+    return imports
+
+
 PLUGIN_NAME = "vlm"
 PROMPT = "描述这张图片的内容"
 
@@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request):
 
 def init_plugin(config):
     global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
 
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
@@ -279,15 +291,16 @@ def init_plugin(config):
         )
         logger.info(f"Using device: {device}")
 
-        florence_model = AutoModelForCausalLM.from_pretrained(
-            "microsoft/Florence-2-base-ft",
-            torch_dtype=torch_dtype,
-            attn_implementation="sdpa",
-            trust_remote_code=True,
-        ).to(device)
-        florence_processor = AutoProcessor.from_pretrained(
-            "microsoft/Florence-2-base-ft", trust_remote_code=True
-        )
+        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+            florence_model = AutoModelForCausalLM.from_pretrained(
+                "microsoft/Florence-2-base-ft",
+                torch_dtype=torch_dtype,
+                attn_implementation="sdpa",
+                trust_remote_code=True,
+            ).to(device)
+            florence_processor = AutoProcessor.from_pretrained(
+                "microsoft/Florence-2-base-ft", trust_remote_code=True
+            )
         logger.info("Florence model and processor initialized")
 
         # Print the parameters