feat: support modelscope

arkohut 2024-09-10 01:32:00 +08:00
parent 241911d1d2
commit 378e5bf445
4 changed files with 34 additions and 6 deletions

View File

@@ -51,6 +51,8 @@ class Settings(BaseSettings):
     default_library: str = "screenshots"
     screenshots_dir: str = os.path.join(base_dir, "screenshots")
 
+    use_modelscope: bool = False
+
     typesense_host: str = "localhost"
     typesense_port: str = "8108"
     typesense_protocol: str = "http"
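Note on the new setting: `Settings` extends pydantic's `BaseSettings`, so the flag should also be controllable from the environment without editing the config file. A minimal sketch, assuming a pydantic v1-style `BaseSettings` and no env prefix (both assumptions, not confirmed by this diff):

    import os
    os.environ["USE_MODELSCOPE"] = "true"  # hypothetical env name; the project's env prefix, if any, is unknown

    from pydantic import BaseSettings  # pydantic v1; v2 moved BaseSettings to the pydantic-settings package

    class Settings(BaseSettings):
        use_modelscope: bool = False

    print(Settings().use_modelscope)  # True when the variable above is set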

View File

@@ -7,6 +7,7 @@ from sentence_transformers import SentenceTransformer
 import torch
 import numpy as np
 from pydantic import BaseModel
+from modelscope import snapshot_download
 
 PLUGIN_NAME = "embedding"
@@ -19,6 +20,7 @@ num_dim = None
 endpoint = None
 model_name = None
 device = None
+use_modelscope = None
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -26,7 +28,7 @@ logger = logging.getLogger(__name__)
 def init_embedding_model():
-    global model, device
+    global model, device, use_modelscope
     if torch.cuda.is_available():
         device = torch.device("cuda")
     elif torch.backends.mps.is_available():
@@ -34,7 +36,14 @@ def init_embedding_model():
     else:
         device = torch.device("cpu")
 
-    model = SentenceTransformer(model_name, trust_remote_code=True)
+    if use_modelscope:
+        model_dir = snapshot_download(model_name)
+        logger.info(f"Model downloaded from ModelScope to: {model_dir}")
+    else:
+        model_dir = model_name
+        logger.info(f"Using model: {model_dir}")
+
+    model = SentenceTransformer(model_dir, trust_remote_code=True)
     model.to(device)
     logger.info(f"Embedding model initialized on device: {device}")
@@ -80,11 +89,12 @@ async def embed(request: EmbeddingRequest):
 def init_plugin(config):
-    global enabled, num_dim, endpoint, model_name
+    global enabled, num_dim, endpoint, model_name, use_modelscope
     enabled = config.enabled
     num_dim = config.num_dim
     endpoint = config.endpoint
     model_name = config.model
+    use_modelscope = config.use_modelscope
 
     if enabled:
         init_embedding_model()
@@ -94,6 +104,7 @@ def init_plugin(config):
     logger.info(f"Number of dimensions: {num_dim}")
     logger.info(f"Endpoint: {endpoint}")
     logger.info(f"Model: {model_name}")
+    logger.info(f"Use ModelScope: {use_modelscope}")
 
 if __name__ == "__main__":
@@ -113,6 +124,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="Port to run the server on"
     )
+    parser.add_argument(
+        "--use-modelscope", action="store_true", help="Use ModelScope to download the model"
+    )
 
     args = parser.parse_args()
@@ -122,6 +136,7 @@ if __name__ == "__main__":
             self.num_dim = args.num_dim
             self.endpoint = "what ever"
             self.model = args.model
+            self.use_modelscope = args.use_modelscope
 
     init_plugin(Config(args))
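Taken together, the standalone entry point picks the flag up from the command line, e.g. `python <plugin script> --use-modelscope` (script path elided here). A small self-contained check of the argparse behavior the wiring relies on:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-modelscope", action="store_true",
        help="Use ModelScope to download the model",
    )
    args = parser.parse_args(["--use-modelscope"])
    assert args.use_modelscope is True  # dashes in the flag become underscores on the namespace

Because `store_true` defaults to False, existing invocations keep the Hugging Face download path unchanged.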

View File

@@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoProcessor
 from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
+from modelscope import snapshot_download
 
 def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
     if not str(filename).endswith("modeling_florence2.py"):
@@ -270,6 +270,7 @@ def init_plugin(config):
     concurrency = config.concurrency
     force_jpeg = config.force_jpeg
     use_local = config.use_local
+    use_modelscope = config.use_modelscope
     semaphore = asyncio.Semaphore(concurrency)
 
     if use_local:
@@ -291,15 +292,22 @@ def init_plugin(config):
         )
         logger.info(f"Using device: {device}")
 
+        if use_modelscope:
+            model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
+            logger.info(f"Model downloaded from ModelScope to: {model_dir}")
+        else:
+            model_dir = "microsoft/Florence-2-base-ft"
+            logger.info(f"Using model: {model_dir}")
+
         with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
             florence_model = AutoModelForCausalLM.from_pretrained(
-                "microsoft/Florence-2-base-ft",
+                model_dir,
                 torch_dtype=torch_dtype,
                 attn_implementation="sdpa",
                 trust_remote_code=True,
             ).to(device)
             florence_processor = AutoProcessor.from_pretrained(
-                "microsoft/Florence-2-base-ft", trust_remote_code=True
+                model_dir, trust_remote_code=True
             )
         logger.info("Florence model and processor initialized")
@@ -311,6 +319,7 @@ def init_plugin(config):
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
     logger.info(f"Use Local: {use_local}")
+    logger.info(f"Use ModelScope: {use_modelscope}")
 
 if __name__ == "__main__":
@@ -329,6 +338,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--port", type=int, default=8000, help="Port to run the server on"
     )
+    parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
 
     args = parser.parse_args()

View File

@@ -41,6 +41,7 @@ dependencies = [
     "numpy",
     "timm",
     "einops",
+    "modelscope",
 ]
 
 [project.urls]
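Both plugins import `modelscope` at module import time, so as committed it does have to be an unconditional dependency. If it were ever made optional, a guarded import would be the usual softening; a sketch, not part of this commit:

    try:
        from modelscope import snapshot_download
    except ImportError:  # modelscope not installed
        snapshot_download = None  # only needed when use_modelscope is enabled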