From 69aca0153af25ff3f245d7b01789e4d53ab9a84a Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 19:43:17 +0800
Subject: [PATCH] feat: make embedding a default plugin

---
 memos/config.py                 |   5 +-
 memos/indexing.py               |  27 ++---
 memos/plugins/embedding/main.py | 131 ++++++++++++++++++++++++++++++++
 memos/server.py                 |  37 +++++----
 memos_ml_backends/server.py     |  49 +-----------
 pyproject.toml                  |   3 +
 6 files changed, 179 insertions(+), 73 deletions(-)
 create mode 100644 memos/plugins/embedding/main.py

diff --git a/memos/config.py b/memos/config.py
index 79799c1..2c5bde6 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -32,9 +32,10 @@ class OCRSettings(BaseModel):
 
 
 class EmbeddingSettings(BaseModel):
+    enabled: bool = True
     num_dim: int = 768
-    ollama_endpoint: str = "http://localhost:11434"
-    ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
+    endpoint: str = "http://localhost:11434/api/embed"
+    model: str = "jinaai/jina-embeddings-v2-base-zh"
 
 
 class Settings(BaseSettings):
diff --git a/memos/indexing.py b/memos/indexing.py
index f110273..5f5d3ba 100644
--- a/memos/indexing.py
+++ b/memos/indexing.py
@@ -46,14 +46,19 @@ def parse_date_fields(entity):
     }
 
 
-def get_embeddings(texts: List[str]) -> List[List[float]]:
+async def get_embeddings(texts: List[str]) -> List[List[float]]:
     print(f"Getting embeddings for {len(texts)} texts")
-    ollama_endpoint = settings.embedding.ollama_endpoint
-    ollama_model = settings.embedding.ollama_model
-    with httpx.Client() as client:
-        response = client.post(
-            f"{ollama_endpoint}/api/embed",
-            json={"model": ollama_model, "input": texts},
+
+    if settings.embedding.enabled:
+        endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
+    else:
+        endpoint = settings.embedding.endpoint
+
+    model = settings.embedding.model
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            endpoint,
+            json={"model": model, "input": texts},
             timeout=30
         )
     if response.status_code == 200:
@@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
     return metadata_text
 
 
-def bulk_upsert(client, entities):
+async def bulk_upsert(client, entities):
     documents = []
     metadata_texts = []
     entities_with_metadata = []
@@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
             ).model_dump(mode="json")
         )
 
-    embeddings = get_embeddings(metadata_texts)
+    embeddings = await get_embeddings(metadata_texts)
     for doc, embedding, entity in zip(documents, embeddings, entities):
         if entity in entities_with_metadata:
             doc["embedding"] = embedding
@@ -259,7 +264,7 @@ def list_all_entities(
     )
 
 
-def search_entities(
+async def search_entities(
     client,
     q: str,
     library_ids: List[int] = None,
@@ -287,7 +292,7 @@ def search_entities(
     filter_by_str = " && ".join(filter_by) if filter_by else ""
 
     # Convert q to embedding using get_embeddings and take the first embedding
-    embedding = get_embeddings([q])[0]
+    embedding = (await get_embeddings([q]))[0]
 
     common_search_params = {
         "collection": TYPESENSE_COLLECTION_NAME,
diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py
new file mode 100644
index 0000000..89e837a
--- /dev/null
+++ b/memos/plugins/embedding/main.py
@@ -0,0 +1,131 @@
+import asyncio
+from typing import List
+from fastapi import APIRouter, HTTPException
+import logging
+import uvicorn
+from sentence_transformers import SentenceTransformer
+import torch
+import numpy as np
+from pydantic import BaseModel
+
+PLUGIN_NAME = "embedding"
+
+router = APIRouter(tags=[PLUGIN_NAME],
+                   responses={404: {"description": "Not found"}})
+
+# Global variables
+enabled = False
+model = None
+num_dim = None
+endpoint = None
+model_name = None
+device = None
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def init_embedding_model():
+    global model, device
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    model = SentenceTransformer(model_name, trust_remote_code=True)
+    model.to(device)
+    logger.info(f"Embedding model initialized on device: {device}")
+
+
+def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
+    embeddings = model.encode(input_texts, convert_to_tensor=True)
+    embeddings = embeddings.cpu().numpy()
+    # Normalize embeddings
+    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
+    norms[norms == 0] = 1
+    embeddings = embeddings / norms
+    return embeddings.tolist()
+
+
+class EmbeddingRequest(BaseModel):
+    input: List[str]
+
+
+class EmbeddingResponse(BaseModel):
+    embeddings: List[List[float]]
+
+
+@router.get("/")
+async def read_root():
+    return {"healthy": True, "enabled": enabled}
+
+
+@router.post("", include_in_schema=False)
+@router.post("/", response_model=EmbeddingResponse)
+async def embed(request: EmbeddingRequest):
+    try:
+        if not request.input:
+            return EmbeddingResponse(embeddings=[])
+
+        # Run the embedding generation in a separate thread to avoid blocking
+        loop = asyncio.get_event_loop()
+        embeddings = await loop.run_in_executor(None, generate_embeddings, request.input)
+        return EmbeddingResponse(embeddings=embeddings)
+    except Exception as e:
+        raise HTTPException(
+            status_code=500, detail=f"Error generating embeddings: {str(e)}"
+        )
+
+
+def init_plugin(config):
+    global enabled, num_dim, endpoint, model_name
+    enabled = config.enabled
+    num_dim = config.num_dim
+    endpoint = config.endpoint
+    model_name = config.model
+
+    if enabled:
+        init_embedding_model()
+
+    logger.info("Embedding plugin initialized")
+    logger.info(f"Enabled: {enabled}")
+    logger.info(f"Number of dimensions: {num_dim}")
+    logger.info(f"Endpoint: {endpoint}")
+    logger.info(f"Model: {model_name}")
+
+
+if __name__ == "__main__":
+    import argparse
+    from fastapi import FastAPI
+
+    parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
+    parser.add_argument(
+        "--num-dim", type=int, default=768, help="Number of embedding dimensions"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="jinaai/jina-embeddings-v2-base-zh",
+        help="Embedding model name",
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to run the server on"
+    )
+
+    args = parser.parse_args()
+
+    class Config:
+        def __init__(self, args):
+            self.enabled = True
+            self.num_dim = args.num_dim
+            self.endpoint = "what ever"  # placeholder; unused when running standalone
+            self.model = args.model
+
+    init_plugin(Config(args))
+
+    app = FastAPI()
+    app.include_router(router)
+
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
diff --git a/memos/server.py b/memos/server.py
index 9702b68..d48e355 100644
--- a/memos/server.py
+++ b/memos/server.py
@@ -23,6 +23,7 @@ import typesense
 from .config import get_database_path, settings
 from memos.plugins.vlm import main as vlm_main
 from memos.plugins.ocr import main as ocr_main
+from memos.plugins.embedding import main as embedding_main
 from . import crud
 from . import indexing
 from .schemas import (
@@ -84,18 +85,6 @@ app.mount(
     "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
 )
 
-# Add VLM plugin router
-if settings.vlm.enabled:
-    print("VLM plugin is enabled")
-    vlm_main.init_plugin(settings.vlm)
-    app.include_router(vlm_main.router, prefix="/plugins/vlm")
-
-# Add OCR plugin router
-if settings.ocr.enabled:
-    print("OCR plugin is enabled")
-    ocr_main.init_plugin(settings.ocr)
-    app.include_router(ocr_main.router, prefix="/plugins/ocr")
-
 
 @app.get("/favicon.png", response_class=FileResponse)
 async def favicon_png():
@@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
     )
 
     try:
-        indexing.bulk_upsert(client, entities)
+        await indexing.bulk_upsert(client, entities)
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(
 
 
 @app.get("/search", response_model=SearchResult, tags=["search"])
-async def search_entities(
+async def search_entities_route(
     q: str,
     library_ids: str = Query(None, description="Comma-separated list of library IDs"),
     folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
@@ -502,7 +491,7 @@ async def search_entities(
         [date.strip() for date in created_dates.split(",")] if created_dates else None
     )
     try:
-        return indexing.search_entities(
+        return await indexing.search_entities(
             client,
             q,
             library_ids,
@@ -727,6 +716,24 @@ def run_server():
     print(f"VLM plugin enabled: {settings.vlm}")
     print(f"OCR plugin enabled: {settings.ocr}")
 
+    # Add VLM plugin router
+    if settings.vlm.enabled:
+        print("VLM plugin is enabled")
+        vlm_main.init_plugin(settings.vlm)
+        app.include_router(vlm_main.router, prefix="/plugins/vlm")
+
+    # Add OCR plugin router
+    if settings.ocr.enabled:
+        print("OCR plugin is enabled")
+        ocr_main.init_plugin(settings.ocr)
+        app.include_router(ocr_main.router, prefix="/plugins/ocr")
+
+    # Add Embedding plugin router
+    if settings.embedding.enabled:
+        print("Embedding plugin is enabled")
+        embedding_main.init_plugin(settings.embedding)
+        app.include_router(embedding_main.router, prefix="/plugins/embed")
+
     uvicorn.run(
         "memos.server:app",
         host=settings.server_host,  # Use the new server_host setting
diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py
index 8eb0653..7a60825 100644
--- a/memos_ml_backends/server.py
+++ b/memos_ml_backends/server.py
@@ -1,7 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import List, Dict, Any, Optional
-from sentence_transformers import SentenceTransformer
 import numpy as np
 import httpx
 import torch
@@ -34,27 +33,6 @@ torch_dtype = (
 print(f"Using device: {device}")
 
 
-def init_embedding_model():
-    model = SentenceTransformer(
-        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
-    )
-    model.to(device)
-    return model
-
-
-embedding_model = init_embedding_model()
-
-
-def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
-    embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
-    embeddings = embeddings.cpu().numpy()
-    # normalized embeddings
-    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
-    norms[norms == 0] = 1
-    embeddings = embeddings / norms
-    return embeddings.tolist()
-
-
 # Add a configuration option to choose the model
 parser = argparse.ArgumentParser(description="Run the server with specified model")
 parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
@@ -68,7 +46,10 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
 if use_florence_model:
     # Load Florence-2 model
     florence_model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True
+        "microsoft/Florence-2-base-ft",
+        torch_dtype=torch_dtype,
+        attn_implementation="sdpa",
+        trust_remote_code=True,
     ).to(device)
     florence_processor = AutoProcessor.from_pretrained(
         "microsoft/Florence-2-base-ft", trust_remote_code=True
@@ -175,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
 app = FastAPI()
 
 
-class EmbeddingRequest(BaseModel):
-    input: List[str]
-
-
-class EmbeddingResponse(BaseModel):
-    embeddings: List[List[float]]
-
-
-@app.post("/api/embed", response_model=EmbeddingResponse)
-async def create_embeddings(request: EmbeddingRequest):
-    try:
-        if not request.input:
-            return EmbeddingResponse(embeddings=[])
-
-        embeddings = generate_embeddings(request.input)  # use the new method
-        return EmbeddingResponse(embeddings=embeddings)
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Error generating embeddings: {str(e)}"
-        )
-
-
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Dict[str, Any]]
diff --git a/pyproject.toml b/pyproject.toml
index d6d509b..c52a9ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,9 @@ dependencies = [
     "pyobjc; sys_platform == 'darwin'",
     "pyobjc-core; sys_platform == 'darwin'",
     "pyobjc-framework-Quartz; sys_platform == 'darwin'",
+    "sentence-transformers",
+    "torch",
+    "numpy",
 ]
 
 [project.urls]
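
--
A minimal smoke-test sketch for the new /plugins/embed route added above. The
host and port are assumptions (a local memos server on its defaults), and the
sample texts are illustrative; only the request/response shapes ("input",
"embeddings") come from the EmbeddingRequest and EmbeddingResponse models in
this patch.

    import httpx

    # Assumed local server address; adjust to settings.server_host / server_port.
    ENDPOINT = "http://localhost:8000/plugins/embed"

    texts = ["hello world", "memos now embeds by default"]
    resp = httpx.post(ENDPOINT, json={"input": texts}, timeout=30)
    resp.raise_for_status()

    embeddings = resp.json()["embeddings"]
    # The plugin L2-normalizes vectors; with the default model, num_dim is 768.
    print(len(embeddings), len(embeddings[0]))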