mirror of https://github.com/tcsenpai/pensieve.git (synced 2025-06-06 03:05:25 +00:00)

feat: make embedding a default plugin

parent 57110db73f
commit 69aca0153a
@@ -32,9 +32,10 @@ class OCRSettings(BaseModel):
 
 
 class EmbeddingSettings(BaseModel):
+    enabled: bool = True
     num_dim: int = 768
-    ollama_endpoint: str = "http://localhost:11434"
-    ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
+    endpoint: str = "http://localhost:11434/api/embed"
+    model: str = "jinaai/jina-embeddings-v2-base-zh"
 
 
 class Settings(BaseSettings):
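For illustration, the routing these fields imply: the defaults keep the new built-in plugin, while enabled=False points the indexer at the external endpoint instead. A minimal sketch (the EmbeddingSettings class is copied from the hunk above; the driver lines at the bottom are illustration only, not part of the commit):

    from pydantic import BaseModel

    class EmbeddingSettings(BaseModel):
        enabled: bool = True
        num_dim: int = 768
        endpoint: str = "http://localhost:11434/api/embed"
        model: str = "jinaai/jina-embeddings-v2-base-zh"

    cfg = EmbeddingSettings()
    # get_embeddings() below chooses its target from this flag
    target = "built-in /plugins/embed route" if cfg.enabled else cfg.endpoint
    print(target)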
@@ -46,14 +46,19 @@ def parse_date_fields(entity):
     }
 
 
-def get_embeddings(texts: List[str]) -> List[List[float]]:
+async def get_embeddings(texts: List[str]) -> List[List[float]]:
     print(f"Getting embeddings for {len(texts)} texts")
-    ollama_endpoint = settings.embedding.ollama_endpoint
-    ollama_model = settings.embedding.ollama_model
-    with httpx.Client() as client:
-        response = client.post(
-            f"{ollama_endpoint}/api/embed",
-            json={"model": ollama_model, "input": texts},
+
+    if settings.embedding.enabled:
+        endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
+    else:
+        endpoint = settings.embedding.endpoint
+
+    model = settings.embedding.model
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            endpoint,
+            json={"model": model, "input": texts},
             timeout=30
         )
     if response.status_code == 200:
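get_embeddings is now a coroutine, so every caller must await it. A standalone sketch of driving it (assumes a configured settings object and a reachable embedding backend; the import path is taken from the server's `from . import indexing`):

    import asyncio
    from memos.indexing import get_embeddings  # assumed import path

    async def demo():
        vectors = await get_embeddings(["hello", "world"])
        print(len(vectors), len(vectors[0]))  # 2 texts, num_dim floats each

    asyncio.run(demo())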
@@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
     return metadata_text
 
 
-def bulk_upsert(client, entities):
+async def bulk_upsert(client, entities):
     documents = []
     metadata_texts = []
     entities_with_metadata = []
@@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
             ).model_dump(mode="json")
         )
 
-    embeddings = get_embeddings(metadata_texts)
+    embeddings = await get_embeddings(metadata_texts)
     for doc, embedding, entity in zip(documents, embeddings, entities):
        if entity in entities_with_metadata:
            doc["embedding"] = embedding
@@ -259,7 +264,7 @@ def list_all_entities(
     )
 
 
-def search_entities(
+async def search_entities(
     client,
     q: str,
     library_ids: List[int] = None,
@@ -287,7 +292,7 @@ def search_entities(
     filter_by_str = " && ".join(filter_by) if filter_by else ""
 
     # Convert q to embedding using get_embeddings and take the first embedding
-    embedding = get_embeddings([q])[0]
+    embedding = (await get_embeddings([q]))[0]
 
     common_search_params = {
         "collection": TYPESENSE_COLLECTION_NAME,
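The added parentheses are load-bearing: subscription binds tighter than await, so without them Python would try to index the coroutine object itself and raise a TypeError. The coroutine has to be awaited first:

    embedding = (await get_embeddings([q]))[0]  # await, then index the result
    # await get_embeddings([q])[0]              # TypeError: 'coroutine' object is not subscriptable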
memos/plugins/embedding/main.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+import asyncio
+from typing import List
+from fastapi import APIRouter, HTTPException
+import logging
+import uvicorn
+from sentence_transformers import SentenceTransformer
+import torch
+import numpy as np
+from pydantic import BaseModel
+
+PLUGIN_NAME = "embedding"
+
+router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
+
+# Global variables
+enabled = False
+model = None
+num_dim = None
+endpoint = None
+model_name = None
+device = None
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def init_embedding_model():
+    global model, device
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    model = SentenceTransformer(model_name, trust_remote_code=True)
+    model.to(device)
+    logger.info(f"Embedding model initialized on device: {device}")
+
+
+def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
+    embeddings = model.encode(input_texts, convert_to_tensor=True)
+    embeddings = embeddings.cpu().numpy()
+    # Normalize embeddings
+    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
+    norms[norms == 0] = 1
+    embeddings = embeddings / norms
+    return embeddings.tolist()
+
+
+class EmbeddingRequest(BaseModel):
+    input: List[str]
+
+
+class EmbeddingResponse(BaseModel):
+    embeddings: List[List[float]]
+
+
+@router.get("/")
+async def read_root():
+    return {"healthy": True, "enabled": enabled}
+
+
+@router.post("", include_in_schema=False)
+@router.post("/", response_model=EmbeddingResponse)
+async def embed(request: EmbeddingRequest):
+    try:
+        if not request.input:
+            return EmbeddingResponse(embeddings=[])
+
+        # Run the embedding generation in a separate thread to avoid blocking
+        loop = asyncio.get_event_loop()
+        embeddings = await loop.run_in_executor(None, generate_embeddings, request.input)
+        return EmbeddingResponse(embeddings=embeddings)
+    except Exception as e:
+        raise HTTPException(
+            status_code=500, detail=f"Error generating embeddings: {str(e)}"
+        )
+
+
+def init_plugin(config):
+    global enabled, num_dim, endpoint, model_name
+    enabled = config.enabled
+    num_dim = config.num_dim
+    endpoint = config.endpoint
+    model_name = config.model
+
+    if enabled:
+        init_embedding_model()
+
+    logger.info("Embedding plugin initialized")
+    logger.info(f"Enabled: {enabled}")
+    logger.info(f"Number of dimensions: {num_dim}")
+    logger.info(f"Endpoint: {endpoint}")
+    logger.info(f"Model: {model_name}")
+
+
+if __name__ == "__main__":
+    import argparse
+    from fastapi import FastAPI
+
+    parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
+    parser.add_argument(
+        "--num-dim", type=int, default=768, help="Number of embedding dimensions"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="jinaai/jina-embeddings-v2-base-zh",
+        help="Embedding model name",
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to run the server on"
+    )
+
+    args = parser.parse_args()
+
+    class Config:
+        def __init__(self, args):
+            self.enabled = True
+            self.num_dim = args.num_dim
+            self.endpoint = "what ever"
+            self.model = args.model
+
+    init_plugin(Config(args))
+
+    app = FastAPI()
+    app.include_router(router)
+
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
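Once the router is mounted under /plugins/embed (see the server changes below), the endpoint can be exercised from any HTTP client. A sketch (the host and port are assumptions about a local deployment; the request schema above only reads "input", so extra fields such as "model" are ignored):

    import httpx

    resp = httpx.post(
        "http://localhost:8839/plugins/embed",  # assumed host/port
        json={"input": ["hello", "world"]},
        timeout=30,
    )
    resp.raise_for_status()
    embeddings = resp.json()["embeddings"]
    print(len(embeddings), len(embeddings[0]))  # 2 vectors of num_dim floats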
@@ -23,6 +23,7 @@ import typesense
 from .config import get_database_path, settings
 from memos.plugins.vlm import main as vlm_main
 from memos.plugins.ocr import main as ocr_main
+from memos.plugins.embedding import main as embedding_main
 from . import crud
 from . import indexing
 from .schemas import (
@@ -84,18 +85,6 @@ app.mount(
     "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
 )
 
-# Add VLM plugin router
-if settings.vlm.enabled:
-    print("VLM plugin is enabled")
-    vlm_main.init_plugin(settings.vlm)
-    app.include_router(vlm_main.router, prefix="/plugins/vlm")
-
-# Add OCR plugin router
-if settings.ocr.enabled:
-    print("OCR plugin is enabled")
-    ocr_main.init_plugin(settings.ocr)
-    app.include_router(ocr_main.router, prefix="/plugins/ocr")
-
 
 @app.get("/favicon.png", response_class=FileResponse)
 async def favicon_png():
@@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
         )
 
     try:
-        indexing.bulk_upsert(client, entities)
+        await indexing.bulk_upsert(client, entities)
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(
 
 
 @app.get("/search", response_model=SearchResult, tags=["search"])
-async def search_entities(
+async def search_entities_route(
     q: str,
     library_ids: str = Query(None, description="Comma-separated list of library IDs"),
     folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
@@ -502,7 +491,7 @@ async def search_entities(
         [date.strip() for date in created_dates.split(",")] if created_dates else None
     )
     try:
-        return indexing.search_entities(
+        return await indexing.search_entities(
             client,
             q,
             library_ids,
@@ -727,6 +716,24 @@ def run_server():
     print(f"VLM plugin enabled: {settings.vlm}")
     print(f"OCR plugin enabled: {settings.ocr}")
 
+    # Add VLM plugin router
+    if settings.vlm.enabled:
+        print("VLM plugin is enabled")
+        vlm_main.init_plugin(settings.vlm)
+        app.include_router(vlm_main.router, prefix="/plugins/vlm")
+
+    # Add OCR plugin router
+    if settings.ocr.enabled:
+        print("OCR plugin is enabled")
+        ocr_main.init_plugin(settings.ocr)
+        app.include_router(ocr_main.router, prefix="/plugins/ocr")
+
+    # Add Embedding plugin router
+    if settings.embedding.enabled:
+        print("Embedding plugin is enabled")
+        embedding_main.init_plugin(settings.embedding)
+        app.include_router(embedding_main.router, prefix="/plugins/embed")
+
     uvicorn.run(
         "memos.server:app",
         host=settings.server_host,  # Use the new server_host setting
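Plugin registration thus moves from import time into run_server(), so routers are attached to app just before uvicorn starts. The pattern in miniature (a sketch, not the project's actual module layout):

    from fastapi import APIRouter, FastAPI
    import uvicorn

    app = FastAPI()
    router = APIRouter()

    @router.get("/")
    async def health():
        return {"healthy": True}

    def run_server():
        # Routers added before uvicorn.run() are in place when the
        # server starts accepting requests.
        app.include_router(router, prefix="/plugins/embed")
        uvicorn.run(app, host="127.0.0.1", port=8000)

    if __name__ == "__main__":
        run_server()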
@@ -1,7 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import List, Dict, Any, Optional
-from sentence_transformers import SentenceTransformer
 import numpy as np
 import httpx
 import torch
@@ -34,27 +33,6 @@ torch_dtype = (
 print(f"Using device: {device}")
 
 
-def init_embedding_model():
-    model = SentenceTransformer(
-        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
-    )
-    model.to(device)
-    return model
-
-
-embedding_model = init_embedding_model()
-
-
-def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
-    embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
-    embeddings = embeddings.cpu().numpy()
-    # normalized embeddings
-    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
-    norms[norms == 0] = 1
-    embeddings = embeddings / norms
-    return embeddings.tolist()
-
-
 # Add a configuration option to choose the model
 parser = argparse.ArgumentParser(description="Run the server with specified model")
 parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
@@ -68,7 +46,10 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
 if use_florence_model:
     # Load Florence-2 model
     florence_model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True
+        "microsoft/Florence-2-base-ft",
+        torch_dtype=torch_dtype,
+        attn_implementation="sdpa",
+        trust_remote_code=True,
     ).to(device)
     florence_processor = AutoProcessor.from_pretrained(
         "microsoft/Florence-2-base-ft", trust_remote_code=True
@@ -175,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
 app = FastAPI()
 
 
-class EmbeddingRequest(BaseModel):
-    input: List[str]
-
-
-class EmbeddingResponse(BaseModel):
-    embeddings: List[List[float]]
-
-
-@app.post("/api/embed", response_model=EmbeddingResponse)
-async def create_embeddings(request: EmbeddingRequest):
-    try:
-        if not request.input:
-            return EmbeddingResponse(embeddings=[])
-
-        embeddings = generate_embeddings(request.input)  # use the new method
-        return EmbeddingResponse(embeddings=embeddings)
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Error generating embeddings: {str(e)}"
-        )
-
-
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Dict[str, Any]]
@@ -36,6 +36,9 @@ dependencies = [
     "pyobjc; sys_platform == 'darwin'",
     "pyobjc-core; sys_platform == 'darwin'",
     "pyobjc-framework-Quartz; sys_platform == 'darwin'",
+    "sentence-transformers",
+    "torch",
+    "numpy",
 ]
 
 [project.urls]