diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index ad636b6..6190a73 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -12,6 +12,18 @@ import io import torch from transformers import AutoModelForCausalLM, AutoProcessor +from unittest.mock import patch +from transformers.dynamic_module_utils import get_imports + + +def fixed_get_imports(filename: str | os.PathLike) -> list[str]: + if not str(filename).endswith("modeling_florence2.py"): + return get_imports(filename) + imports = get_imports(filename) + imports.remove("flash_attn") + return imports + + PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" @@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype - + modelname = config.modelname endpoint = config.endpoint token = config.token @@ -279,15 +291,16 @@ def init_plugin(config): ) logger.info(f"Using device: {device}") - florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", - torch_dtype=torch_dtype, - attn_implementation="sdpa", - trust_remote_code=True, - ).to(device) - florence_processor = AutoProcessor.from_pretrained( - "microsoft/Florence-2-base-ft", trust_remote_code=True - ) + with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): + florence_model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device) + florence_processor = AutoProcessor.from_pretrained( + "microsoft/Florence-2-base-ft", trust_remote_code=True + ) logger.info("Florence model and processor initialized") # Print the parameters