mirror of https://github.com/tcsenpai/pensieve.git (synced 2025-06-06 19:25:24 +00:00)

feat: make embedding a default plugin

parent 57110db73f
commit 69aca0153a
@@ -32,9 +32,10 @@ class OCRSettings(BaseModel):
 
 
 class EmbeddingSettings(BaseModel):
+    enabled: bool = True
     num_dim: int = 768
-    ollama_endpoint: str = "http://localhost:11434"
-    ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
+    endpoint: str = "http://localhost:11434/api/embed"
+    model: str = "jinaai/jina-embeddings-v2-base-zh"
 
 
 class Settings(BaseSettings):
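For reference, a minimal sketch of how the new defaults behave, assuming only pydantic's standard field semantics (the surrounding Settings wiring is not shown in this hunk):

from pydantic import BaseModel

class EmbeddingSettings(BaseModel):
    enabled: bool = True
    num_dim: int = 768
    endpoint: str = "http://localhost:11434/api/embed"
    model: str = "jinaai/jina-embeddings-v2-base-zh"

settings = EmbeddingSettings()
assert settings.enabled  # the embedding plugin is now on by default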
@@ -46,14 +46,19 @@ def parse_date_fields(entity):
     }
 
 
-def get_embeddings(texts: List[str]) -> List[List[float]]:
+async def get_embeddings(texts: List[str]) -> List[List[float]]:
     print(f"Getting embeddings for {len(texts)} texts")
-    ollama_endpoint = settings.embedding.ollama_endpoint
-    ollama_model = settings.embedding.ollama_model
-    with httpx.Client() as client:
-        response = client.post(
-            f"{ollama_endpoint}/api/embed",
-            json={"model": ollama_model, "input": texts},
+    if settings.embedding.enabled:
+        endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
+    else:
+        endpoint = settings.embedding.endpoint
+
+    model = settings.embedding.model
+
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            endpoint,
+            json={"model": model, "input": texts},
             timeout=30
         )
         if response.status_code == 200:
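The behavioral core of this hunk is the endpoint selection: with the plugin enabled, embedding requests loop back into the server's own /plugins/embed route; otherwise they go to the configured external service (the old Ollama path). A standalone sketch of that branch, using illustrative stand-in settings rather than the project's real config object:

from types import SimpleNamespace

# Stand-in values; the real project reads these from its Settings object.
settings = SimpleNamespace(
    server_host="localhost",
    server_port=8080,
    embedding=SimpleNamespace(enabled=True, endpoint="http://localhost:11434/api/embed"),
)

if settings.embedding.enabled:
    # Built-in plugin: call back into this server's own embed route.
    endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
else:
    # Fall back to the external service configured in settings.
    endpoint = settings.embedding.endpoint

print(endpoint)  # -> http://localhost:8080/plugins/embed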
@@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
     return metadata_text
 
 
-def bulk_upsert(client, entities):
+async def bulk_upsert(client, entities):
     documents = []
     metadata_texts = []
     entities_with_metadata = []
@@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
             ).model_dump(mode="json")
         )
 
-    embeddings = get_embeddings(metadata_texts)
+    embeddings = await get_embeddings(metadata_texts)
     for doc, embedding, entity in zip(documents, embeddings, entities):
         if entity in entities_with_metadata:
            doc["embedding"] = embedding
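Making get_embeddings a coroutine forces the await to propagate upward, which is why bulk_upsert (and its callers in server.py below) turn async in the same commit. A self-contained sketch with stub bodies; only the function names mirror the diff:

import asyncio

async def get_embeddings(texts):
    # Stub standing in for the real HTTP call: one fixed-size vector per text.
    return [[0.0] * 3 for _ in texts]

async def bulk_upsert(client, entities):
    # Awaiting here is what forces bulk_upsert itself to become async.
    embeddings = await get_embeddings([e["text"] for e in entities])
    return list(zip(entities, embeddings))

print(asyncio.run(bulk_upsert(None, [{"text": "hello"}])))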
@@ -259,7 +264,7 @@ def list_all_entities(
     )
 
 
-def search_entities(
+async def search_entities(
     client,
     q: str,
     library_ids: List[int] = None,
@@ -287,7 +292,7 @@ def search_entities(
     filter_by_str = " && ".join(filter_by) if filter_by else ""
 
     # Convert q to embedding using get_embeddings and take the first embedding
-    embedding = get_embeddings([q])[0]
+    embedding = (await get_embeddings([q]))[0]
 
     common_search_params = {
         "collection": TYPESENSE_COLLECTION_NAME,
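Note the parentheses in the new call: subscription binds tighter than await, so `await get_embeddings([q])[0]` would index the coroutine object and raise a TypeError. A tiny stub demonstrating the pattern:

import asyncio

async def get_embeddings(texts):
    # Hypothetical stub standing in for the real HTTP call.
    return [[0.1, 0.2]]

async def main(q="query"):
    embedding = (await get_embeddings([q]))[0]  # await first, then index
    return embedding

print(asyncio.run(main()))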
memos/plugins/embedding/main.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+import asyncio
+from typing import List
+from fastapi import APIRouter, HTTPException
+import logging
+import uvicorn
+from sentence_transformers import SentenceTransformer
+import torch
+import numpy as np
+from pydantic import BaseModel
+
+PLUGIN_NAME = "embedding"
+
+router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
+
+# Global variables
+enabled = False
+model = None
+num_dim = None
+endpoint = None
+model_name = None
+device = None
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def init_embedding_model():
+    global model, device
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    model = SentenceTransformer(model_name, trust_remote_code=True)
+    model.to(device)
+    logger.info(f"Embedding model initialized on device: {device}")
+
+
+def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
+    embeddings = model.encode(input_texts, convert_to_tensor=True)
+    embeddings = embeddings.cpu().numpy()
+    # Normalize embeddings
+    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
+    norms[norms == 0] = 1
+    embeddings = embeddings / norms
+    return embeddings.tolist()
+
+
+class EmbeddingRequest(BaseModel):
+    input: List[str]
+
+
+class EmbeddingResponse(BaseModel):
+    embeddings: List[List[float]]
+
+
+@router.get("/")
+async def read_root():
+    return {"healthy": True, "enabled": enabled}
+
+
+@router.post("", include_in_schema=False)
+@router.post("/", response_model=EmbeddingResponse)
+async def embed(request: EmbeddingRequest):
+    try:
+        if not request.input:
+            return EmbeddingResponse(embeddings=[])
+
+        # Run the embedding generation in a separate thread to avoid blocking
+        loop = asyncio.get_event_loop()
+        embeddings = await loop.run_in_executor(None, generate_embeddings, request.input)
+        return EmbeddingResponse(embeddings=embeddings)
+    except Exception as e:
+        raise HTTPException(
+            status_code=500, detail=f"Error generating embeddings: {str(e)}"
+        )
+
+
+def init_plugin(config):
+    global enabled, num_dim, endpoint, model_name
+    enabled = config.enabled
+    num_dim = config.num_dim
+    endpoint = config.endpoint
+    model_name = config.model
+
+    if enabled:
+        init_embedding_model()
+
+    logger.info("Embedding plugin initialized")
+    logger.info(f"Enabled: {enabled}")
+    logger.info(f"Number of dimensions: {num_dim}")
+    logger.info(f"Endpoint: {endpoint}")
+    logger.info(f"Model: {model_name}")
+
+
+if __name__ == "__main__":
+    import argparse
+    from fastapi import FastAPI
+
+    parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
+    parser.add_argument(
+        "--num-dim", type=int, default=768, help="Number of embedding dimensions"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="jinaai/jina-embeddings-v2-base-zh",
+        help="Embedding model name",
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to run the server on"
+    )
+
+    args = parser.parse_args()
+
+    class Config:
+        def __init__(self, args):
+            self.enabled = True
+            self.num_dim = args.num_dim
+            self.endpoint = "what ever"
+            self.model = args.model
+
+    init_plugin(Config(args))
+
+    app = FastAPI()
+    app.include_router(router)
+
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
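Once the server mounts this router (see server.py below), the plugin can be exercised over HTTP. A hypothetical client call, with host and port as illustrative values; the handler only declares `input`, so the extra `model` field sent by get_embeddings should be ignored under pydantic's default extra-field handling:

import httpx

resp = httpx.post(
    "http://localhost:8080/plugins/embed",  # host/port are illustrative
    json={"input": ["hello world"]},
    timeout=30,
)
resp.raise_for_status()
print(len(resp.json()["embeddings"][0]))  # expect num_dim floats, 768 by default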
@@ -23,6 +23,7 @@ import typesense
 from .config import get_database_path, settings
 from memos.plugins.vlm import main as vlm_main
 from memos.plugins.ocr import main as ocr_main
+from memos.plugins.embedding import main as embedding_main
 from . import crud
 from . import indexing
 from .schemas import (
@@ -84,18 +85,6 @@ app.mount(
     "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
 )
 
-# Add VLM plugin router
-if settings.vlm.enabled:
-    print("VLM plugin is enabled")
-    vlm_main.init_plugin(settings.vlm)
-    app.include_router(vlm_main.router, prefix="/plugins/vlm")
-
-# Add OCR plugin router
-if settings.ocr.enabled:
-    print("OCR plugin is enabled")
-    ocr_main.init_plugin(settings.ocr)
-    app.include_router(ocr_main.router, prefix="/plugins/ocr")
-
 
 @app.get("/favicon.png", response_class=FileResponse)
 async def favicon_png():
@@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
     )
 
     try:
-        indexing.bulk_upsert(client, entities)
+        await indexing.bulk_upsert(client, entities)
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(
 
 
 @app.get("/search", response_model=SearchResult, tags=["search"])
-async def search_entities(
+async def search_entities_route(
     q: str,
     library_ids: str = Query(None, description="Comma-separated list of library IDs"),
     folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
@@ -502,7 +491,7 @@ async def search_entities(
         [date.strip() for date in created_dates.split(",")] if created_dates else None
     )
     try:
-        return indexing.search_entities(
+        return await indexing.search_entities(
             client,
             q,
             library_ids,
@@ -727,6 +716,24 @@ def run_server():
     print(f"VLM plugin enabled: {settings.vlm}")
     print(f"OCR plugin enabled: {settings.ocr}")
 
+    # Add VLM plugin router
+    if settings.vlm.enabled:
+        print("VLM plugin is enabled")
+        vlm_main.init_plugin(settings.vlm)
+        app.include_router(vlm_main.router, prefix="/plugins/vlm")
+
+    # Add OCR plugin router
+    if settings.ocr.enabled:
+        print("OCR plugin is enabled")
+        ocr_main.init_plugin(settings.ocr)
+        app.include_router(ocr_main.router, prefix="/plugins/ocr")
+
+    # Add Embedding plugin router
+    if settings.embedding.enabled:
+        print("Embedding plugin is enabled")
+        embedding_main.init_plugin(settings.embedding)
+        app.include_router(embedding_main.router, prefix="/plugins/embed")
+
     uvicorn.run(
         "memos.server:app",
         host=settings.server_host,  # Use the new server_host setting
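Structurally, plugin registration moves from module import time (removed above) into run_server(), so routers are only mounted when the server process actually starts. A minimal sketch of that pattern with a stub router; names are illustrative:

from fastapi import APIRouter, FastAPI

router = APIRouter()

@router.post("/")
async def embed():
    return {"embeddings": []}

app = FastAPI()

def run_server():
    # Mounted at startup rather than as an import side effect.
    app.include_router(router, prefix="/plugins/embed")

run_server()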
@@ -1,7 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import List, Dict, Any, Optional
-from sentence_transformers import SentenceTransformer
 import numpy as np
 import httpx
 import torch
@@ -34,27 +33,6 @@ torch_dtype = (
 print(f"Using device: {device}")
 
 
-def init_embedding_model():
-    model = SentenceTransformer(
-        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
-    )
-    model.to(device)
-    return model
-
-
-embedding_model = init_embedding_model()
-
-
-def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
-    embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
-    embeddings = embeddings.cpu().numpy()
-    # normalized embeddings
-    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
-    norms[norms == 0] = 1
-    embeddings = embeddings / norms
-    return embeddings.tolist()
-
-
 # Add a configuration option to choose the model
 parser = argparse.ArgumentParser(description="Run the server with specified model")
 parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
@@ -68,7 +46,10 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
 if use_florence_model:
     # Load Florence-2 model
     florence_model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True
+        "microsoft/Florence-2-base-ft",
+        torch_dtype=torch_dtype,
+        attn_implementation="sdpa",
+        trust_remote_code=True,
     ).to(device)
     florence_processor = AutoProcessor.from_pretrained(
         "microsoft/Florence-2-base-ft", trust_remote_code=True
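The functional change here is attn_implementation="sdpa", which asks transformers to back attention with PyTorch's fused scaled_dot_product_attention kernel. A small illustration of the underlying op, with arbitrary illustrative shapes:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

out = F.scaled_dot_product_attention(q, k, v)  # fused attention kernel
print(out.shape)  # torch.Size([1, 8, 16, 64])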
@@ -175,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
 app = FastAPI()
 
 
-class EmbeddingRequest(BaseModel):
-    input: List[str]
-
-
-class EmbeddingResponse(BaseModel):
-    embeddings: List[List[float]]
-
-
-@app.post("/api/embed", response_model=EmbeddingResponse)
-async def create_embeddings(request: EmbeddingRequest):
-    try:
-        if not request.input:
-            return EmbeddingResponse(embeddings=[])
-
-        embeddings = generate_embeddings(request.input)  # use the new method
-        return EmbeddingResponse(embeddings=embeddings)
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Error generating embeddings: {str(e)}"
-        )
-
-
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Dict[str, Any]]
@@ -36,6 +36,9 @@ dependencies = [
     "pyobjc; sys_platform == 'darwin'",
     "pyobjc-core; sys_platform == 'darwin'",
     "pyobjc-framework-Quartz; sys_platform == 'darwin'",
+    "sentence-transformers",
+    "torch",
+    "numpy",
 ]
 
 [project.urls]
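With sentence-transformers, torch, and numpy promoted to hard dependencies, a quick import check (illustrative only) confirms the new requirements resolve in the target environment:

# Sanity check for the newly declared runtime dependencies.
import numpy
import torch
import sentence_transformers

print(numpy.__version__, torch.__version__, sentence_transformers.__version__)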