From 378e5bf445b800f2928b2d5363c846474295b50b Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 01:32:00 +0800 Subject: [PATCH] feat: support modelscope --- memos/config.py | 2 ++ memos/plugins/embedding/main.py | 21 ++++++++++++++++++--- memos/plugins/vlm/main.py | 16 +++++++++++++--- pyproject.toml | 1 + 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/memos/config.py b/memos/config.py index a572e6f..6986b7a 100644 --- a/memos/config.py +++ b/memos/config.py @@ -51,6 +51,8 @@ class Settings(BaseSettings): default_library: str = "screenshots" screenshots_dir: str = os.path.join(base_dir, "screenshots") + use_modelscope: bool = False + typesense_host: str = "localhost" typesense_port: str = "8108" typesense_protocol: str = "http" diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py index 89e837a..c7fbec4 100644 --- a/memos/plugins/embedding/main.py +++ b/memos/plugins/embedding/main.py @@ -7,6 +7,7 @@ from sentence_transformers import SentenceTransformer import torch import numpy as np from pydantic import BaseModel +from modelscope import snapshot_download PLUGIN_NAME = "embedding" @@ -19,6 +20,7 @@ num_dim = None endpoint = None model_name = None device = None +use_modelscope = None # Configure logger logging.basicConfig(level=logging.INFO) @@ -26,7 +28,7 @@ logger = logging.getLogger(__name__) def init_embedding_model(): - global model, device + global model, device, use_modelscope if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): @@ -34,7 +36,14 @@ def init_embedding_model(): else: device = torch.device("cpu") - model = SentenceTransformer(model_name, trust_remote_code=True) + if use_modelscope: + model_dir = snapshot_download(model_name) + logger.info(f"Model downloaded from ModelScope to: {model_dir}") + else: + model_dir = model_name + logger.info(f"Using model: {model_dir}") + + model = SentenceTransformer(model_dir, trust_remote_code=True) model.to(device) logger.info(f"Embedding model initialized on device: {device}") @@ -80,11 +89,12 @@ async def embed(request: EmbeddingRequest): def init_plugin(config): - global enabled, num_dim, endpoint, model_name + global enabled, num_dim, endpoint, model_name, use_modelscope enabled = config.enabled num_dim = config.num_dim endpoint = config.endpoint model_name = config.model + use_modelscope = config.use_modelscope if enabled: init_embedding_model() @@ -94,6 +104,7 @@ def init_plugin(config): logger.info(f"Number of dimensions: {num_dim}") logger.info(f"Endpoint: {endpoint}") logger.info(f"Model: {model_name}") + logger.info(f"Use ModelScope: {use_modelscope}") if __name__ == "__main__": @@ -113,6 +124,9 @@ if __name__ == "__main__": parser.add_argument( "--port", type=int, default=8000, help="Port to run the server on" ) + parser.add_argument( + "--use-modelscope", action="store_true", help="Use ModelScope to download the model" + ) args = parser.parse_args() @@ -122,6 +136,7 @@ if __name__ == "__main__": self.num_dim = args.num_dim self.endpoint = "what ever" self.model = args.model + self.use_modelscope = args.use_modelscope init_plugin(Config(args)) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index 6190a73..fc905a4 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoProcessor from unittest.mock import patch from transformers.dynamic_module_utils import get_imports - +from modelscope import snapshot_download def fixed_get_imports(filename: str | os.PathLike) -> list[str]: if not str(filename).endswith("modeling_florence2.py"): @@ -270,6 +270,7 @@ def init_plugin(config): concurrency = config.concurrency force_jpeg = config.force_jpeg use_local = config.use_local + use_modelscope = config.use_modelscope semaphore = asyncio.Semaphore(concurrency) if use_local: @@ -291,15 +292,22 @@ def init_plugin(config): ) logger.info(f"Using device: {device}") + if use_modelscope: + model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft') + logger.info(f"Model downloaded from ModelScope to: {model_dir}") + else: + model_dir = "microsoft/Florence-2-base-ft" + logger.info(f"Using model: {model_dir}") + with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", + model_dir, torch_dtype=torch_dtype, attn_implementation="sdpa", trust_remote_code=True, ).to(device) florence_processor = AutoProcessor.from_pretrained( - "microsoft/Florence-2-base-ft", trust_remote_code=True + model_dir, trust_remote_code=True ) logger.info("Florence model and processor initialized") @@ -311,6 +319,7 @@ def init_plugin(config): logger.info(f"Concurrency: {concurrency}") logger.info(f"Force JPEG: {force_jpeg}") logger.info(f"Use Local: {use_local}") + logger.info(f"Use ModelScope: {use_modelscope}") if __name__ == "__main__": @@ -329,6 +338,7 @@ if __name__ == "__main__": parser.add_argument( "--port", type=int, default=8000, help="Port to run the server on" ) + parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model") args = parser.parse_args() diff --git a/pyproject.toml b/pyproject.toml index 1bcac45..5e44755 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "numpy", "timm", "einops", + "modelscope", ] [project.urls]