feat: make embedding a default plugin

arkohut 2024-09-09 19:43:17 +08:00
parent 57110db73f
commit 69aca0153a
6 changed files with 179 additions and 73 deletions

View File

@@ -32,9 +32,10 @@ class OCRSettings(BaseModel):

 class EmbeddingSettings(BaseModel):
+    enabled: bool = True
     num_dim: int = 768
-    ollama_endpoint: str = "http://localhost:11434"
-    ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
+    endpoint: str = "http://localhost:11434/api/embed"
+    model: str = "jinaai/jina-embeddings-v2-base-zh"

 class Settings(BaseSettings):
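
Note: with enabled defaulting to True, embeddings are served by the built-in plugin, while endpoint and model describe the external fallback (e.g. an Ollama-compatible service) used when it is disabled. A minimal sketch of the resulting routing decision, mirroring the get_embeddings change in the next file (the settings import matches the one used in server.py below):

    from memos.config import settings

    # sketch: pick where embedding requests go
    if settings.embedding.enabled:
        url = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
    else:
        url = settings.embedding.endpoint  # external embedding service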

View File

@@ -46,14 +46,19 @@ def parse_date_fields(entity):
     }

-def get_embeddings(texts: List[str]) -> List[List[float]]:
+async def get_embeddings(texts: List[str]) -> List[List[float]]:
     print(f"Getting embeddings for {len(texts)} texts")
-    ollama_endpoint = settings.embedding.ollama_endpoint
-    ollama_model = settings.embedding.ollama_model
-    with httpx.Client() as client:
-        response = client.post(
-            f"{ollama_endpoint}/api/embed",
-            json={"model": ollama_model, "input": texts},
+
+    if settings.embedding.enabled:
+        endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
+    else:
+        endpoint = settings.embedding.endpoint
+
+    model = settings.embedding.model
+
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            endpoint,
+            json={"model": model, "input": texts},
             timeout=30
         )
         if response.status_code == 200:
@@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
     return metadata_text

-def bulk_upsert(client, entities):
+async def bulk_upsert(client, entities):
     documents = []
     metadata_texts = []
     entities_with_metadata = []
@@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
         ).model_dump(mode="json")
     )

-    embeddings = get_embeddings(metadata_texts)
+    embeddings = await get_embeddings(metadata_texts)
     for doc, embedding, entity in zip(documents, embeddings, entities):
         if entity in entities_with_metadata:
             doc["embedding"] = embedding
@@ -259,7 +264,7 @@ def list_all_entities(
     )

-def search_entities(
+async def search_entities(
     client,
     q: str,
     library_ids: List[int] = None,
@@ -287,7 +292,7 @@ def search_entities(
     filter_by_str = " && ".join(filter_by) if filter_by else ""

     # Convert q to embedding using get_embeddings and take the first embedding
-    embedding = get_embeddings([q])[0]
+    embedding = (await get_embeddings([q]))[0]

     common_search_params = {
         "collection": TYPESENSE_COLLECTION_NAME,

View File

@@ -0,0 +1,131 @@
+import asyncio
+from typing import List
+from fastapi import APIRouter, HTTPException
+import logging
+import uvicorn
+from sentence_transformers import SentenceTransformer
+import torch
+import numpy as np
+from pydantic import BaseModel
+
+PLUGIN_NAME = "embedding"
+
+router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
+
+# Global variables
+enabled = False
+model = None
+num_dim = None
+endpoint = None
+model_name = None
+device = None
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def init_embedding_model():
+    global model, device
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+    model = SentenceTransformer(model_name, trust_remote_code=True)
+    model.to(device)
+    logger.info(f"Embedding model initialized on device: {device}")
+
+
+def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
+    embeddings = model.encode(input_texts, convert_to_tensor=True)
+    embeddings = embeddings.cpu().numpy()
+    # Normalize embeddings
+    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
+    norms[norms == 0] = 1
+    embeddings = embeddings / norms
+
+    return embeddings.tolist()
+
+
+class EmbeddingRequest(BaseModel):
+    input: List[str]
+
+
+class EmbeddingResponse(BaseModel):
+    embeddings: List[List[float]]
+
+
+@router.get("/")
+async def read_root():
+    return {"healthy": True, "enabled": enabled}
+
+
+@router.post("", include_in_schema=False)
+@router.post("/", response_model=EmbeddingResponse)
+async def embed(request: EmbeddingRequest):
+    try:
+        if not request.input:
+            return EmbeddingResponse(embeddings=[])
+
+        # Run the embedding generation in a separate thread to avoid blocking
+        loop = asyncio.get_event_loop()
+        embeddings = await loop.run_in_executor(None, generate_embeddings, request.input)
+
+        return EmbeddingResponse(embeddings=embeddings)
+    except Exception as e:
+        raise HTTPException(
+            status_code=500, detail=f"Error generating embeddings: {str(e)}"
+        )
+
+
+def init_plugin(config):
+    global enabled, num_dim, endpoint, model_name
+    enabled = config.enabled
+    num_dim = config.num_dim
+    endpoint = config.endpoint
+    model_name = config.model
+
+    if enabled:
+        init_embedding_model()
+
+    logger.info("Embedding plugin initialized")
+    logger.info(f"Enabled: {enabled}")
+    logger.info(f"Number of dimensions: {num_dim}")
+    logger.info(f"Endpoint: {endpoint}")
+    logger.info(f"Model: {model_name}")
+
+
+if __name__ == "__main__":
+    import argparse
+    from fastapi import FastAPI
+
+    parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
+    parser.add_argument(
+        "--num-dim", type=int, default=768, help="Number of embedding dimensions"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="jinaai/jina-embeddings-v2-base-zh",
+        help="Embedding model name",
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to run the server on"
+    )
+    args = parser.parse_args()
+
+    class Config:
+        def __init__(self, args):
+            self.enabled = True
+            self.num_dim = args.num_dim
+            self.endpoint = "what ever"
+            self.model = args.model
+
+    init_plugin(Config(args))
+
+    app = FastAPI()
+    app.include_router(router)
+
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
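
For reference, a minimal sketch of calling the new route once the server is running (host and port are illustrative; the request model only requires "input", and unknown fields such as "model" should be ignored by pydantic's defaults):

    import asyncio
    import httpx

    async def main():
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                "http://localhost:8000/plugins/embed",
                json={"input": ["hello world", "你好世界"]},
            )
            resp.raise_for_status()
            # the plugin returns one L2-normalized vector per input text
            print(len(resp.json()["embeddings"]))  # -> 2

    asyncio.run(main())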

View File

@@ -23,6 +23,7 @@ import typesense
 from .config import get_database_path, settings
 from memos.plugins.vlm import main as vlm_main
 from memos.plugins.ocr import main as ocr_main
+from memos.plugins.embedding import main as embedding_main
 from . import crud
 from . import indexing
 from .schemas import (
@@ -84,18 +85,6 @@ app.mount(
     "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
 )

-# Add VLM plugin router
-if settings.vlm.enabled:
-    print("VLM plugin is enabled")
-    vlm_main.init_plugin(settings.vlm)
-    app.include_router(vlm_main.router, prefix="/plugins/vlm")
-
-# Add OCR plugin router
-if settings.ocr.enabled:
-    print("OCR plugin is enabled")
-    ocr_main.init_plugin(settings.ocr)
-    app.include_router(ocr_main.router, prefix="/plugins/ocr")

 @app.get("/favicon.png", response_class=FileResponse)
 async def favicon_png():
@@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
     )
     try:
-        indexing.bulk_upsert(client, entities)
+        await indexing.bulk_upsert(client, entities)
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(

 @app.get("/search", response_model=SearchResult, tags=["search"])
-async def search_entities(
+async def search_entities_route(
     q: str,
     library_ids: str = Query(None, description="Comma-separated list of library IDs"),
     folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
@@ -502,7 +491,7 @@ async def search_entities(
         [date.strip() for date in created_dates.split(",")] if created_dates else None
     )
     try:
-        return indexing.search_entities(
+        return await indexing.search_entities(
             client,
             q,
             library_ids,
@@ -727,6 +716,24 @@ def run_server():
     print(f"VLM plugin enabled: {settings.vlm}")
     print(f"OCR plugin enabled: {settings.ocr}")

+    # Add VLM plugin router
+    if settings.vlm.enabled:
+        print("VLM plugin is enabled")
+        vlm_main.init_plugin(settings.vlm)
+        app.include_router(vlm_main.router, prefix="/plugins/vlm")
+
+    # Add OCR plugin router
+    if settings.ocr.enabled:
+        print("OCR plugin is enabled")
+        ocr_main.init_plugin(settings.ocr)
+        app.include_router(ocr_main.router, prefix="/plugins/ocr")
+
+    # Add Embedding plugin router
+    if settings.embedding.enabled:
+        print("Embedding plugin is enabled")
+        embedding_main.init_plugin(settings.embedding)
+        app.include_router(embedding_main.router, prefix="/plugins/embed")
+
     uvicorn.run(
         "memos.server:app",
         host=settings.server_host,  # Use the new server_host setting
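
Registering the plugin routers inside run_server, rather than at import time, means they are only mounted when the server actually starts. A quick way to confirm the embedding plugin is live, with hypothetical host and port:

    import httpx

    # the plugin's root route reports its health and enabled state
    resp = httpx.get("http://localhost:8000/plugins/embed/")
    print(resp.json())  # e.g. {'healthy': True, 'enabled': True}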

View File

@@ -1,7 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import List, Dict, Any, Optional
-from sentence_transformers import SentenceTransformer
 import numpy as np
 import httpx
 import torch
@@ -34,27 +33,6 @@ torch_dtype = (
 print(f"Using device: {device}")

-def init_embedding_model():
-    model = SentenceTransformer(
-        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
-    )
-    model.to(device)
-    return model
-
-embedding_model = init_embedding_model()
-
-def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
-    embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
-    embeddings = embeddings.cpu().numpy()
-    # normalized embeddings
-    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
-    norms[norms == 0] = 1
-    embeddings = embeddings / norms
-    return embeddings.tolist()

 # Add a configuration option to choose the model
 parser = argparse.ArgumentParser(description="Run the server with specified model")
 parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
@@ -68,7 +46,10 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
 if use_florence_model:
     # Load Florence-2 model
     florence_model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True
+        "microsoft/Florence-2-base-ft",
+        torch_dtype=torch_dtype,
+        attn_implementation="sdpa",
+        trust_remote_code=True,
     ).to(device)
     florence_processor = AutoProcessor.from_pretrained(
         "microsoft/Florence-2-base-ft", trust_remote_code=True
@@ -175,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):

 app = FastAPI()

-class EmbeddingRequest(BaseModel):
-    input: List[str]
-
-class EmbeddingResponse(BaseModel):
-    embeddings: List[List[float]]
-
-@app.post("/api/embed", response_model=EmbeddingResponse)
-async def create_embeddings(request: EmbeddingRequest):
-    try:
-        if not request.input:
-            return EmbeddingResponse(embeddings=[])
-
-        embeddings = generate_embeddings(request.input)  # use the new method
-        return EmbeddingResponse(embeddings=embeddings)
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Error generating embeddings: {str(e)}"
-        )

 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Dict[str, Any]]

View File

@@ -36,6 +36,9 @@ dependencies = [
     "pyobjc; sys_platform == 'darwin'",
     "pyobjc-core; sys_platform == 'darwin'",
     "pyobjc-framework-Quartz; sys_platform == 'darwin'",
+    "sentence-transformers",
+    "torch",
+    "numpy",
 ]

 [project.urls]