feat: support modelscope

This commit is contained in:
arkohut 2024-09-10 01:32:00 +08:00
parent 241911d1d2
commit 378e5bf445
4 changed files with 34 additions and 6 deletions

View File

@ -51,6 +51,8 @@ class Settings(BaseSettings):
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
use_modelscope: bool = False
typesense_host: str = "localhost"
typesense_port: str = "8108"
typesense_protocol: str = "http"

View File

@ -7,6 +7,7 @@ from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from pydantic import BaseModel
from modelscope import snapshot_download
PLUGIN_NAME = "embedding"
@ -19,6 +20,7 @@ num_dim = None
endpoint = None
model_name = None
device = None
use_modelscope = None
# Configure logger
logging.basicConfig(level=logging.INFO)
@ -26,7 +28,7 @@ logger = logging.getLogger(__name__)
def init_embedding_model():
global model, device
global model, device, use_modelscope
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
@ -34,7 +36,14 @@ def init_embedding_model():
else:
device = torch.device("cpu")
model = SentenceTransformer(model_name, trust_remote_code=True)
if use_modelscope:
model_dir = snapshot_download(model_name)
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
else:
model_dir = model_name
logger.info(f"Using model: {model_dir}")
model = SentenceTransformer(model_dir, trust_remote_code=True)
model.to(device)
logger.info(f"Embedding model initialized on device: {device}")
@ -80,11 +89,12 @@ async def embed(request: EmbeddingRequest):
def init_plugin(config):
global enabled, num_dim, endpoint, model_name
global enabled, num_dim, endpoint, model_name, use_modelscope
enabled = config.enabled
num_dim = config.num_dim
endpoint = config.endpoint
model_name = config.model
use_modelscope = config.use_modelscope
if enabled:
init_embedding_model()
@ -94,6 +104,7 @@ def init_plugin(config):
logger.info(f"Number of dimensions: {num_dim}")
logger.info(f"Endpoint: {endpoint}")
logger.info(f"Model: {model_name}")
logger.info(f"Use ModelScope: {use_modelscope}")
if __name__ == "__main__":
@ -113,6 +124,9 @@ if __name__ == "__main__":
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
parser.add_argument(
"--use-modelscope", action="store_true", help="Use ModelScope to download the model"
)
args = parser.parse_args()
@ -122,6 +136,7 @@ if __name__ == "__main__":
self.num_dim = args.num_dim
self.endpoint = "what ever"
self.model = args.model
self.use_modelscope = args.use_modelscope
init_plugin(Config(args))

View File

@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from modelscope import snapshot_download
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
@ -270,6 +270,7 @@ def init_plugin(config):
concurrency = config.concurrency
force_jpeg = config.force_jpeg
use_local = config.use_local
use_modelscope = config.use_modelscope
semaphore = asyncio.Semaphore(concurrency)
if use_local:
@ -291,15 +292,22 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
if use_modelscope:
model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
else:
model_dir = "microsoft/Florence-2-base-ft"
logger.info(f"Using model: {model_dir}")
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
model_dir,
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
model_dir, trust_remote_code=True
)
logger.info("Florence model and processor initialized")
@ -311,6 +319,7 @@ def init_plugin(config):
logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}")
logger.info(f"Use Local: {use_local}")
logger.info(f"Use ModelScope: {use_modelscope}")
if __name__ == "__main__":
@ -329,6 +338,7 @@ if __name__ == "__main__":
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
args = parser.parse_args()

View File

@ -41,6 +41,7 @@ dependencies = [
"numpy",
"timm",
"einops",
"modelscope",
]
[project.urls]