mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat: support modelscope
This commit is contained in:
parent
241911d1d2
commit
378e5bf445
@ -51,6 +51,8 @@ class Settings(BaseSettings):
|
||||
default_library: str = "screenshots"
|
||||
screenshots_dir: str = os.path.join(base_dir, "screenshots")
|
||||
|
||||
use_modelscope: bool = False
|
||||
|
||||
typesense_host: str = "localhost"
|
||||
typesense_port: str = "8108"
|
||||
typesense_protocol: str = "http"
|
||||
|
@ -7,6 +7,7 @@ from sentence_transformers import SentenceTransformer
|
||||
import torch
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
from modelscope import snapshot_download
|
||||
|
||||
PLUGIN_NAME = "embedding"
|
||||
|
||||
@ -19,6 +20,7 @@ num_dim = None
|
||||
endpoint = None
|
||||
model_name = None
|
||||
device = None
|
||||
use_modelscope = None
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -26,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_embedding_model():
|
||||
global model, device
|
||||
global model, device, use_modelscope
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
@ -34,7 +36,14 @@ def init_embedding_model():
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
model = SentenceTransformer(model_name, trust_remote_code=True)
|
||||
if use_modelscope:
|
||||
model_dir = snapshot_download(model_name)
|
||||
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
|
||||
else:
|
||||
model_dir = model_name
|
||||
logger.info(f"Using model: {model_dir}")
|
||||
|
||||
model = SentenceTransformer(model_dir, trust_remote_code=True)
|
||||
model.to(device)
|
||||
logger.info(f"Embedding model initialized on device: {device}")
|
||||
|
||||
@ -80,11 +89,12 @@ async def embed(request: EmbeddingRequest):
|
||||
|
||||
|
||||
def init_plugin(config):
|
||||
global enabled, num_dim, endpoint, model_name
|
||||
global enabled, num_dim, endpoint, model_name, use_modelscope
|
||||
enabled = config.enabled
|
||||
num_dim = config.num_dim
|
||||
endpoint = config.endpoint
|
||||
model_name = config.model
|
||||
use_modelscope = config.use_modelscope
|
||||
|
||||
if enabled:
|
||||
init_embedding_model()
|
||||
@ -94,6 +104,7 @@ def init_plugin(config):
|
||||
logger.info(f"Number of dimensions: {num_dim}")
|
||||
logger.info(f"Endpoint: {endpoint}")
|
||||
logger.info(f"Model: {model_name}")
|
||||
logger.info(f"Use ModelScope: {use_modelscope}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -113,6 +124,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="Port to run the server on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-modelscope", action="store_true", help="Use ModelScope to download the model"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -122,6 +136,7 @@ if __name__ == "__main__":
|
||||
self.num_dim = args.num_dim
|
||||
self.endpoint = "what ever"
|
||||
self.model = args.model
|
||||
self.use_modelscope = args.use_modelscope
|
||||
|
||||
init_plugin(Config(args))
|
||||
|
||||
|
@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
|
||||
from unittest.mock import patch
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
|
||||
from modelscope import snapshot_download
|
||||
|
||||
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
|
||||
if not str(filename).endswith("modeling_florence2.py"):
|
||||
@ -270,6 +270,7 @@ def init_plugin(config):
|
||||
concurrency = config.concurrency
|
||||
force_jpeg = config.force_jpeg
|
||||
use_local = config.use_local
|
||||
use_modelscope = config.use_modelscope
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
if use_local:
|
||||
@ -291,15 +292,22 @@ def init_plugin(config):
|
||||
)
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
if use_modelscope:
|
||||
model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
|
||||
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
|
||||
else:
|
||||
model_dir = "microsoft/Florence-2-base-ft"
|
||||
logger.info(f"Using model: {model_dir}")
|
||||
|
||||
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
||||
florence_model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Florence-2-base-ft",
|
||||
model_dir,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="sdpa",
|
||||
trust_remote_code=True,
|
||||
).to(device)
|
||||
florence_processor = AutoProcessor.from_pretrained(
|
||||
"microsoft/Florence-2-base-ft", trust_remote_code=True
|
||||
model_dir, trust_remote_code=True
|
||||
)
|
||||
logger.info("Florence model and processor initialized")
|
||||
|
||||
@ -311,6 +319,7 @@ def init_plugin(config):
|
||||
logger.info(f"Concurrency: {concurrency}")
|
||||
logger.info(f"Force JPEG: {force_jpeg}")
|
||||
logger.info(f"Use Local: {use_local}")
|
||||
logger.info(f"Use ModelScope: {use_modelscope}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -329,6 +338,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="Port to run the server on"
|
||||
)
|
||||
parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -41,6 +41,7 @@ dependencies = [
|
||||
"numpy",
|
||||
"timm",
|
||||
"einops",
|
||||
"modelscope",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
Loading…
x
Reference in New Issue
Block a user