mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 19:25:24 +00:00
feat: support modelscope
This commit is contained in:
parent
241911d1d2
commit
378e5bf445
@ -51,6 +51,8 @@ class Settings(BaseSettings):
|
|||||||
default_library: str = "screenshots"
|
default_library: str = "screenshots"
|
||||||
screenshots_dir: str = os.path.join(base_dir, "screenshots")
|
screenshots_dir: str = os.path.join(base_dir, "screenshots")
|
||||||
|
|
||||||
|
use_modelscope: bool = False
|
||||||
|
|
||||||
typesense_host: str = "localhost"
|
typesense_host: str = "localhost"
|
||||||
typesense_port: str = "8108"
|
typesense_port: str = "8108"
|
||||||
typesense_protocol: str = "http"
|
typesense_protocol: str = "http"
|
||||||
|
@ -7,6 +7,7 @@ from sentence_transformers import SentenceTransformer
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
PLUGIN_NAME = "embedding"
|
PLUGIN_NAME = "embedding"
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ num_dim = None
|
|||||||
endpoint = None
|
endpoint = None
|
||||||
model_name = None
|
model_name = None
|
||||||
device = None
|
device = None
|
||||||
|
use_modelscope = None
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -26,7 +28,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def init_embedding_model():
|
def init_embedding_model():
|
||||||
global model, device
|
global model, device, use_modelscope
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
@ -34,7 +36,14 @@ def init_embedding_model():
|
|||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|
||||||
model = SentenceTransformer(model_name, trust_remote_code=True)
|
if use_modelscope:
|
||||||
|
model_dir = snapshot_download(model_name)
|
||||||
|
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
|
||||||
|
else:
|
||||||
|
model_dir = model_name
|
||||||
|
logger.info(f"Using model: {model_dir}")
|
||||||
|
|
||||||
|
model = SentenceTransformer(model_dir, trust_remote_code=True)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
logger.info(f"Embedding model initialized on device: {device}")
|
logger.info(f"Embedding model initialized on device: {device}")
|
||||||
|
|
||||||
@ -80,11 +89,12 @@ async def embed(request: EmbeddingRequest):
|
|||||||
|
|
||||||
|
|
||||||
def init_plugin(config):
|
def init_plugin(config):
|
||||||
global enabled, num_dim, endpoint, model_name
|
global enabled, num_dim, endpoint, model_name, use_modelscope
|
||||||
enabled = config.enabled
|
enabled = config.enabled
|
||||||
num_dim = config.num_dim
|
num_dim = config.num_dim
|
||||||
endpoint = config.endpoint
|
endpoint = config.endpoint
|
||||||
model_name = config.model
|
model_name = config.model
|
||||||
|
use_modelscope = config.use_modelscope
|
||||||
|
|
||||||
if enabled:
|
if enabled:
|
||||||
init_embedding_model()
|
init_embedding_model()
|
||||||
@ -94,6 +104,7 @@ def init_plugin(config):
|
|||||||
logger.info(f"Number of dimensions: {num_dim}")
|
logger.info(f"Number of dimensions: {num_dim}")
|
||||||
logger.info(f"Endpoint: {endpoint}")
|
logger.info(f"Endpoint: {endpoint}")
|
||||||
logger.info(f"Model: {model_name}")
|
logger.info(f"Model: {model_name}")
|
||||||
|
logger.info(f"Use ModelScope: {use_modelscope}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -113,6 +124,9 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=8000, help="Port to run the server on"
|
"--port", type=int, default=8000, help="Port to run the server on"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-modelscope", action="store_true", help="Use ModelScope to download the model"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -122,6 +136,7 @@ if __name__ == "__main__":
|
|||||||
self.num_dim = args.num_dim
|
self.num_dim = args.num_dim
|
||||||
self.endpoint = "what ever"
|
self.endpoint = "what ever"
|
||||||
self.model = args.model
|
self.model = args.model
|
||||||
|
self.use_modelscope = args.use_modelscope
|
||||||
|
|
||||||
init_plugin(Config(args))
|
init_plugin(Config(args))
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoProcessor
|
|||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from transformers.dynamic_module_utils import get_imports
|
from transformers.dynamic_module_utils import get_imports
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
|
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
|
||||||
if not str(filename).endswith("modeling_florence2.py"):
|
if not str(filename).endswith("modeling_florence2.py"):
|
||||||
@ -270,6 +270,7 @@ def init_plugin(config):
|
|||||||
concurrency = config.concurrency
|
concurrency = config.concurrency
|
||||||
force_jpeg = config.force_jpeg
|
force_jpeg = config.force_jpeg
|
||||||
use_local = config.use_local
|
use_local = config.use_local
|
||||||
|
use_modelscope = config.use_modelscope
|
||||||
semaphore = asyncio.Semaphore(concurrency)
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
if use_local:
|
if use_local:
|
||||||
@ -291,15 +292,22 @@ def init_plugin(config):
|
|||||||
)
|
)
|
||||||
logger.info(f"Using device: {device}")
|
logger.info(f"Using device: {device}")
|
||||||
|
|
||||||
|
if use_modelscope:
|
||||||
|
model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
|
||||||
|
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
|
||||||
|
else:
|
||||||
|
model_dir = "microsoft/Florence-2-base-ft"
|
||||||
|
logger.info(f"Using model: {model_dir}")
|
||||||
|
|
||||||
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
||||||
florence_model = AutoModelForCausalLM.from_pretrained(
|
florence_model = AutoModelForCausalLM.from_pretrained(
|
||||||
"microsoft/Florence-2-base-ft",
|
model_dir,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
attn_implementation="sdpa",
|
attn_implementation="sdpa",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
).to(device)
|
).to(device)
|
||||||
florence_processor = AutoProcessor.from_pretrained(
|
florence_processor = AutoProcessor.from_pretrained(
|
||||||
"microsoft/Florence-2-base-ft", trust_remote_code=True
|
model_dir, trust_remote_code=True
|
||||||
)
|
)
|
||||||
logger.info("Florence model and processor initialized")
|
logger.info("Florence model and processor initialized")
|
||||||
|
|
||||||
@ -311,6 +319,7 @@ def init_plugin(config):
|
|||||||
logger.info(f"Concurrency: {concurrency}")
|
logger.info(f"Concurrency: {concurrency}")
|
||||||
logger.info(f"Force JPEG: {force_jpeg}")
|
logger.info(f"Force JPEG: {force_jpeg}")
|
||||||
logger.info(f"Use Local: {use_local}")
|
logger.info(f"Use Local: {use_local}")
|
||||||
|
logger.info(f"Use ModelScope: {use_modelscope}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -329,6 +338,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=8000, help="Port to run the server on"
|
"--port", type=int, default=8000, help="Port to run the server on"
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ dependencies = [
|
|||||||
"numpy",
|
"numpy",
|
||||||
"timm",
|
"timm",
|
||||||
"einops",
|
"einops",
|
||||||
|
"modelscope",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user