2024-09-10 01:32:00 +08:00

147 lines
4.1 KiB
Python

import asyncio
from typing import List
from fastapi import APIRouter, HTTPException
import logging
import uvicorn
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from pydantic import BaseModel
from modelscope import snapshot_download
PLUGIN_NAME = "embedding"

# Router that carries every route of this plugin, tagged with the plugin name.
router = APIRouter(
    tags=[PLUGIN_NAME],
    responses={404: {"description": "Not found"}},
)

# Module-level plugin state. `enabled` starts out False and every other slot
# starts out None; init_plugin() and init_embedding_model() fill them in.
enabled = False
model = num_dim = endpoint = model_name = device = use_modelscope = None

# Logging setup: configure the root logger at INFO level and create a logger
# named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def init_embedding_model():
    """Pick a torch device, resolve the model location, and load the model.

    The device is CUDA when available, otherwise MPS when available,
    otherwise CPU.  When ``use_modelscope`` is truthy, ``model_name`` is
    passed to ``snapshot_download`` and the returned directory is loaded;
    otherwise ``model_name`` itself is handed to ``SentenceTransformer``.

    Side effects: rebinds the module globals ``model`` and ``device``.
    """
    global model, device, use_modelscope

    # Decide on the backend name first, then build the device object once.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)

    if not use_modelscope:
        model_dir = model_name
        logger.info(f"Using model: {model_dir}")
    else:
        model_dir = snapshot_download(model_name)
        logger.info(f"Model downloaded from ModelScope to: {model_dir}")

    # NOTE(review): trust_remote_code=True presumably lets the model
    # repository supply its own code -- confirm the model source is trusted.
    model = SentenceTransformer(model_dir, trust_remote_code=True)
    model.to(device)
    logger.info(f"Embedding model initialized on device: {device}")
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
    """Encode texts with the loaded model and return L2-normalised vectors.

    Args:
        input_texts: Texts to embed.

    Returns:
        One list of floats per input text, each row divided by its L2 norm
        (rows whose norm is zero are divided by 1, i.e. left unchanged).
        An empty input yields an empty list without calling the model.
    """
    if not input_texts:
        # Nothing to encode: skip the model call and the row-wise norm.
        return []

    vectors = model.encode(input_texts, convert_to_tensor=True).cpu().numpy()

    # Row-wise L2 normalisation; a zero norm is replaced by 1 so the
    # division below never divides by zero.
    norms = np.linalg.norm(vectors, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1
    return (vectors / norms).tolist()
class EmbeddingRequest(BaseModel):
    """Request body for the embedding endpoint."""

    # Texts to be embedded.
    input: List[str]
class EmbeddingResponse(BaseModel):
    """Response body for the embedding endpoint."""

    # One vector (list of floats) per embedded text.
    embeddings: List[List[float]]
@router.get("/")
async def read_root():
    """Health probe: report liveness and the current ``enabled`` flag."""
    status = {"healthy": True, "enabled": enabled}
    return status
@router.post("", include_in_schema=False)
@router.post("/", response_model=EmbeddingResponse)
async def embed(request: EmbeddingRequest):
    """Return embeddings for the texts in ``request.input``.

    An empty ``input`` list short-circuits to an empty response.  Otherwise
    ``generate_embeddings`` is run through the running loop's default
    executor so the event loop is not blocked while it executes.

    Raises:
        HTTPException: status 500, carrying the error text, when embedding
            generation or response construction fails.
    """
    if not request.input:
        return EmbeddingResponse(embeddings=[])

    try:
        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, generate_embeddings, request.input
        )
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the text.
        logger.exception("Error generating embeddings")
        raise HTTPException(
            status_code=500, detail=f"Error generating embeddings: {str(e)}"
        ) from e
def init_plugin(config):
    """Copy plugin settings from ``config`` into module globals and start up.

    Reads ``enabled``, ``num_dim``, ``endpoint``, ``model`` and
    ``use_modelscope`` from ``config``.  The embedding model is loaded only
    when ``config.enabled`` is truthy; the settings are logged either way.
    """
    global enabled, num_dim, endpoint, model_name, use_modelscope

    enabled = config.enabled
    num_dim = config.num_dim
    endpoint = config.endpoint
    model_name = config.model
    use_modelscope = config.use_modelscope

    if enabled:
        init_embedding_model()

    # Emit the startup summary, one log record per line, in a fixed order.
    summary = (
        "Embedding plugin initialized",
        f"Enabled: {enabled}",
        f"Number of dimensions: {num_dim}",
        f"Endpoint: {endpoint}",
        f"Model: {model_name}",
        f"Use ModelScope: {use_modelscope}",
    )
    for line in summary:
        logger.info(line)
if __name__ == "__main__":
    # Standalone mode: parse CLI options, initialise the plugin, and serve
    # its router from a dedicated FastAPI app.
    import argparse

    from fastapi import FastAPI

    parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
    parser.add_argument(
        "--num-dim", type=int, default=768, help="Number of embedding dimensions"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="jinaai/jina-embeddings-v2-base-zh",
        help="Embedding model name",
    )
    # The bind address used to be hard-coded; the default keeps the old value.
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host/interface to bind the server to",
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--use-modelscope",
        action="store_true",
        help="Use ModelScope to download the model",
    )
    args = parser.parse_args()

    class Config:
        """Minimal stand-in for the config object that init_plugin() reads."""

        def __init__(self, args):
            self.enabled = True  # standalone mode always enables the plugin
            self.num_dim = args.num_dim
            self.endpoint = "what ever"
            self.model = args.model
            self.use_modelscope = args.use_modelscope

    init_plugin(Config(args))

    app = FastAPI()
    app.include_router(router)
    uvicorn.run(app, host=args.host, port=args.port)