From 7b0635ebf644bdda63e9b748e4c84ef75dfd4eda Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:18:57 +0800 Subject: [PATCH] feat(embedding): make the embedding plugin a core function --- memos/embedding.py | 54 ++++++++++++ memos/indexing.py | 27 ++---- memos/plugins/embedding/main.py | 146 -------------------------------- memos/server.py | 6 -- 4 files changed, 61 insertions(+), 172 deletions(-) create mode 100644 memos/embedding.py delete mode 100644 memos/plugins/embedding/main.py diff --git a/memos/embedding.py b/memos/embedding.py new file mode 100644 index 0000000..adb407d --- /dev/null +++ b/memos/embedding.py @@ -0,0 +1,54 @@ +from typing import List +from sentence_transformers import SentenceTransformer +import torch +import numpy as np +from modelscope import snapshot_download +from .config import settings +import logging + +# Configure logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Global variables +model = None +device = None + +def init_embedding_model(): + global model, device + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + if settings.embedding.use_modelscope: + model_dir = snapshot_download(settings.embedding.model) + logger.info(f"Model downloaded from ModelScope to: {model_dir}") + else: + model_dir = settings.embedding.model + logger.info(f"Using model: {model_dir}") + + model = SentenceTransformer(model_dir, trust_remote_code=True) + model.to(device) + logger.info(f"Embedding model initialized on device: {device}") + +def generate_embeddings(texts: List[str]) -> List[List[float]]: + global model + + if model is None: + init_embedding_model() + + if not texts: + return [] + + embeddings = model.encode(texts, convert_to_tensor=True) + embeddings = embeddings.cpu().numpy() + + # Normalize embeddings + norms = 
np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) + norms[norms == 0] = 1 + embeddings = embeddings / norms + + return embeddings.tolist() \ No newline at end of file diff --git a/memos/indexing.py b/memos/indexing.py index 3ee6716..f4446a4 100644 --- a/memos/indexing.py +++ b/memos/indexing.py @@ -17,7 +17,7 @@ from .schemas import ( RequestParams, ) from .config import settings, TYPESENSE_COLLECTION_NAME - +from .embedding import generate_embeddings def convert_metadata_value(metadata: EntityMetadata): if metadata.data_type == MetadataType.JSON_DATA: @@ -49,28 +49,15 @@ def parse_date_fields(entity): async def get_embeddings(texts: List[str]) -> List[List[float]]: print(f"Getting embeddings for {len(texts)} texts") - if settings.embedding.enabled: - endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed" - else: - endpoint = settings.embedding.endpoint - - model = settings.embedding.model - async with httpx.AsyncClient() as client: - response = await client.post( - endpoint, - json={"model": model, "input": texts}, - timeout=30 - ) - if response.status_code == 200: - print("Successfully retrieved embeddings from the embedding service.") + try: + embeddings = generate_embeddings(texts) + print("Successfully generated embeddings.") return [ [round(float(x), 5) for x in embedding] - for embedding in response.json()["embeddings"] + for embedding in embeddings ] - else: - raise Exception( - f"Failed to get embeddings: {response.text} {response.status_code}" - ) + except Exception as e: + raise Exception(f"Failed to generate embeddings: {str(e)}") def generate_metadata_text(metadata_entries): diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py deleted file mode 100644 index c7fbec4..0000000 --- a/memos/plugins/embedding/main.py +++ /dev/null @@ -1,146 +0,0 @@ -import asyncio -from typing import List -from fastapi import APIRouter, HTTPException -import logging -import uvicorn -from sentence_transformers import 
SentenceTransformer -import torch -import numpy as np -from pydantic import BaseModel -from modelscope import snapshot_download - -PLUGIN_NAME = "embedding" - -router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}}) - -# Global variables -enabled = False -model = None -num_dim = None -endpoint = None -model_name = None -device = None -use_modelscope = None - -# Configure logger -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def init_embedding_model(): - global model, device, use_modelscope - if torch.cuda.is_available(): - device = torch.device("cuda") - elif torch.backends.mps.is_available(): - device = torch.device("mps") - else: - device = torch.device("cpu") - - if use_modelscope: - model_dir = snapshot_download(model_name) - logger.info(f"Model downloaded from ModelScope to: {model_dir}") - else: - model_dir = model_name - logger.info(f"Using model: {model_dir}") - - model = SentenceTransformer(model_dir, trust_remote_code=True) - model.to(device) - logger.info(f"Embedding model initialized on device: {device}") - - -def generate_embeddings(input_texts: List[str]) -> List[List[float]]: - embeddings = model.encode(input_texts, convert_to_tensor=True) - embeddings = embeddings.cpu().numpy() - # Normalize embeddings - norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) - norms[norms == 0] = 1 - embeddings = embeddings / norms - return embeddings.tolist() - - -class EmbeddingRequest(BaseModel): - input: List[str] - - -class EmbeddingResponse(BaseModel): - embeddings: List[List[float]] - - -@router.get("/") -async def read_root(): - return {"healthy": True, "enabled": enabled} - - -@router.post("", include_in_schema=False) -@router.post("/", response_model=EmbeddingResponse) -async def embed(request: EmbeddingRequest): - try: - if not request.input: - return EmbeddingResponse(embeddings=[]) - - # Run the embedding generation in a separate thread to avoid blocking - loop = 
asyncio.get_event_loop() - embeddings = await loop.run_in_executor(None, generate_embeddings, request.input) - return EmbeddingResponse(embeddings=embeddings) - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error generating embeddings: {str(e)}" - ) - - -def init_plugin(config): - global enabled, num_dim, endpoint, model_name, use_modelscope - enabled = config.enabled - num_dim = config.num_dim - endpoint = config.endpoint - model_name = config.model - use_modelscope = config.use_modelscope - - if enabled: - init_embedding_model() - - logger.info("Embedding plugin initialized") - logger.info(f"Enabled: {enabled}") - logger.info(f"Number of dimensions: {num_dim}") - logger.info(f"Endpoint: {endpoint}") - logger.info(f"Model: {model_name}") - logger.info(f"Use ModelScope: {use_modelscope}") - - -if __name__ == "__main__": - import argparse - from fastapi import FastAPI - - parser = argparse.ArgumentParser(description="Embedding Plugin Configuration") - parser.add_argument( - "--num-dim", type=int, default=768, help="Number of embedding dimensions" - ) - parser.add_argument( - "--model", - type=str, - default="jinaai/jina-embeddings-v2-base-zh", - help="Embedding model name", - ) - parser.add_argument( - "--port", type=int, default=8000, help="Port to run the server on" - ) - parser.add_argument( - "--use-modelscope", action="store_true", help="Use ModelScope to download the model" - ) - - args = parser.parse_args() - - class Config: - def __init__(self, args): - self.enabled = True - self.num_dim = args.num_dim - self.endpoint = "what ever" - self.model = args.model - self.use_modelscope = args.use_modelscope - - init_plugin(Config(args)) - - app = FastAPI() - app.include_router(router) - - uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/memos/server.py b/memos/server.py index 13f2e99..89a3791 100644 --- a/memos/server.py +++ b/memos/server.py @@ -751,12 +751,6 @@ def run_server(): ocr_main.init_plugin(settings.ocr) 
app.include_router(ocr_main.router, prefix="/plugins/ocr") - # Add Embedding plugin router - if settings.embedding.enabled: - print("Embedding plugin is enabled") - embedding_main.init_plugin(settings.embedding) - app.include_router(embedding_main.router, prefix="/plugins/embed") - uvicorn.run( "memos.server:app", host=settings.server_host, # Use the new server_host setting