From 3912d165f67a3d5ef50b43190811624a1eef38bc Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 23:27:40 +0800 Subject: [PATCH] Revert "fix: skip flash attn https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d" This reverts commit 5b1194f1bc6ad055fe69a9a2207ef1c31cc73721. --- memos/plugins/vlm/main.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index 6190a73..ad636b6 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -12,18 +12,6 @@ import io import torch from transformers import AutoModelForCausalLM, AutoProcessor -from unittest.mock import patch -from transformers.dynamic_module_utils import get_imports - - -def fixed_get_imports(filename: str | os.PathLike) -> list[str]: - if not str(filename).endswith("modeling_florence2.py"): - return get_imports(filename) - imports = get_imports(filename) - imports.remove("flash_attn") - return imports - - PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" @@ -263,7 +251,7 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype - + modelname = config.modelname endpoint = config.endpoint token = config.token @@ -291,16 +279,15 @@ def init_plugin(config): ) logger.info(f"Using device: {device}") - with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): - florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", - torch_dtype=torch_dtype, - attn_implementation="sdpa", - trust_remote_code=True, - ).to(device) - florence_processor = AutoProcessor.from_pretrained( - "microsoft/Florence-2-base-ft", trust_remote_code=True - ) + florence_model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device) + florence_processor = AutoProcessor.from_pretrained( + "microsoft/Florence-2-base-ft", trust_remote_code=True + ) logger.info("Florence model and processor initialized") # Print the parameters