feat(embedding): support remote embedding

This commit is contained in:
arkohut 2024-10-08 23:18:02 +08:00
parent bd5e7d6a6f
commit 8f3120bac8
6 changed files with 71 additions and 40 deletions

View File

@ -38,6 +38,7 @@ class EmbeddingSettings(BaseModel):
endpoint: str = "http://localhost:11434/api/embed"
model: str = "jinaai/jina-embeddings-v2-base-zh"
use_modelscope: bool = False
use_local: bool = True
class TypesenseSettings(BaseModel):

View File

@ -27,10 +27,11 @@ from .models import (
)
import numpy as np
from collections import defaultdict
from .embedding import generate_embeddings
from .embedding import get_embeddings
import logging
from sqlite_vec import serialize_float32
import time
import asyncio
logger = logging.getLogger(__name__)
@ -458,17 +459,19 @@ def full_text_search(
return ids
def vec_search(
async def vec_search(
query: str,
db: Session,
limit: int = 200,
library_ids: Optional[List[int]] = None,
start: Optional[int] = None,
end: Optional[int] = None,
end: Optional[int] = None
) -> List[int]:
query_embedding = generate_embeddings([query])[0]
query_embedding = await get_embeddings([query])
if not query_embedding:
return []
query_embedding = query_embedding[0]
sql_query = """
SELECT entities.id FROM entities
@ -514,7 +517,7 @@ def reciprocal_rank_fusion(
return sorted_results
def hybrid_search(
async def hybrid_search(
query: str,
db: Session,
limit: int = 200,
@ -530,7 +533,7 @@ def hybrid_search(
logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")
vec_start = time.time()
vec_results = vec_search(query, db, limit, library_ids, start, end)
vec_results = await vec_search(query, db, limit, library_ids, start, end)
vec_end = time.time()
logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")
@ -559,4 +562,4 @@ def hybrid_search(
total_time = end_time - start_time
logger.info(f"Total hybrid search time: {total_time:.4f} seconds")
return result
return result

View File

@ -5,6 +5,8 @@ import numpy as np
from modelscope import snapshot_download
from .config import settings
import logging
import httpx
import asyncio
# Configure logger
logging.basicConfig(level=logging.INFO)
@ -14,6 +16,7 @@ logger = logging.getLogger(__name__)
model = None
device = None
def init_embedding_model():
global model, device
if torch.cuda.is_available():
@ -34,21 +37,49 @@ def init_embedding_model():
model.to(device)
logger.info(f"Embedding model initialized on device: {device}")
def generate_embeddings(texts: List[str]) -> List[List[float]]:
    """Embed *texts* with the local sentence-transformer model.

    Lazily initializes the module-level model on first use and returns one
    L2-normalized vector (list of floats) per input text.

    Args:
        texts: Strings to embed; may be empty.

    Returns:
        A list of embedding vectors, empty when *texts* is empty.
    """
    global model
    # Short-circuit BEFORE the lazy init so callers with nothing to embed
    # never pay the model-load cost (the original checked after init and
    # also carried a duplicated trailing return from a bad merge).
    if not texts:
        return []
    if model is None:
        init_embedding_model()
    embeddings = model.encode(texts, convert_to_tensor=True)
    embeddings = embeddings.cpu().numpy()
    # L2-normalize each row; zero-norm rows are guarded to avoid division
    # by zero (they pass through unchanged).
    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1
    return (embeddings / norms).tolist()
async def get_embeddings(texts: List[str]) -> List[List[float]]:
    """Embed *texts* via the configured backend (local model or remote API).

    Whichever backend runs, every vector component is rounded to 5 decimal
    places before being returned, one vector per input text.
    """
    if settings.embedding.use_local:
        raw_vectors = generate_embeddings(texts)
    else:
        raw_vectors = await get_remote_embeddings(texts)
    # Round to 5 decimal places to keep stored vectors compact and stable.
    rounded: List[List[float]] = []
    for vector in raw_vectors:
        rounded.append([round(float(component), 5) for component in vector])
    return rounded
async def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
    """Fetch embeddings for *texts* from the configured remote endpoint.

    POSTs ``{"model": ..., "input": texts}`` and expects an ``"embeddings"``
    list in the JSON response.  On any HTTP failure — connection error,
    timeout, or a non-2xx status — an empty list is returned instead of
    raising, so callers can degrade gracefully.
    """
    payload = {"model": settings.embedding.model, "input": texts}
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(settings.embedding.endpoint, json=payload)
            response.raise_for_status()
            result = response.json()
            # A missing key means the endpoint speaks a different schema;
            # treat that like a failed request rather than raising KeyError.
            return result.get("embeddings", [])
        except httpx.HTTPError as e:
            # httpx.HTTPError is the common base of RequestError (network
            # problems) AND HTTPStatusError (raised by raise_for_status on
            # 4xx/5xx).  The original `except httpx.RequestError` let status
            # errors escape, contradicting the return-empty-list contract.
            logger.error(f"Error fetching embeddings from remote endpoint: {e}")
            return []  # Return an empty list instead of raising an exception

View File

@ -17,7 +17,8 @@ from .schemas import (
RequestParams,
)
from .config import settings, TYPESENSE_COLLECTION_NAME
from .embedding import generate_embeddings
from .embedding import get_embeddings
def convert_metadata_value(metadata: EntityMetadata):
if metadata.data_type == MetadataType.JSON_DATA:
@ -46,20 +47,6 @@ def parse_date_fields(entity):
}
async def get_embeddings(texts: List[str]) -> List[List[float]]:
    """Generate embeddings for *texts*, rounding each value to 5 decimals.

    Raises:
        Exception: wrapping any failure from the underlying embedding call.
    """
    print(f"Getting embeddings for {len(texts)} texts")
    try:
        embeddings = generate_embeddings(texts)
        print("Successfully generated embeddings.")
        return [
            [round(float(x), 5) for x in embedding]
            for embedding in embeddings
        ]
    except Exception as e:
        # Chain the original exception (`from e`) so the real cause and its
        # traceback are not discarded — the original re-raise dropped them.
        raise Exception(f"Failed to generate embeddings: {str(e)}") from e
def generate_metadata_text(metadata_entries):
# Temporarily not using OCR results
def process_ocr_result(metadata):
@ -102,7 +89,7 @@ async def bulk_upsert(client, entities):
if metadata_text:
metadata_texts.append(metadata_text)
entities_with_metadata.append(entity)
documents.append(
EntityIndexItem(
id=str(entity.id),
@ -151,10 +138,10 @@ async def bulk_upsert(client, entities):
)
def upsert(client, entity):
async def upsert(client, entity):
date_fields = parse_date_fields(entity)
metadata_text = generate_metadata_text(entity.metadata_entries)
embedding = get_embeddings([metadata_text])[0]
embedding = (await get_embeddings([metadata_text]))[0]
entity_data = EntityIndexItem(
id=str(entity.id),

View File

@ -25,8 +25,9 @@ import os
import sys
from pathlib import Path
import json
from .embedding import generate_embeddings
from .embedding import get_embeddings
from sqlite_vec import serialize_float32
import asyncio
class Base(DeclarativeBase):
@ -307,7 +308,7 @@ def update_entity_last_scan_at_for_metadata(mapper, connection, target):
session.commit()
def update_fts_and_vec(mapper, connection, target):
async def update_fts_and_vec(mapper, connection, target):
session = Session(bind=connection)
# Prepare FTS data
@ -379,10 +380,13 @@ def update_fts_and_vec(mapper, connection, target):
metadata_text = "\n".join(
[
f"{entry.key}: {entry.value}"
for entry in target.metadata_entries if entry.key != 'ocr_result'
for entry in target.metadata_entries
if entry.key != "ocr_result"
]
)
embeddings = generate_embeddings([metadata_text])
# Use the new get_embeddings function
embeddings = await get_embeddings([metadata_text])
if not embeddings:
embedding = []
else:
@ -428,6 +432,12 @@ def delete_fts_and_vec(mapper, connection, target):
)
event.listen(EntityModel, "after_insert", update_fts_and_vec)
event.listen(EntityModel, "after_update", update_fts_and_vec)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
# Synchronous adapter so the async update_fts_and_vec coroutine can be
# registered as a SQLAlchemy event listener (listeners are invoked
# synchronously during flush).
def update_fts_and_vec_sync(mapper, connection, target):
    # NOTE(review): asyncio.run() raises RuntimeError if an event loop is
    # already running in this thread (e.g. a flush triggered inside an async
    # web handler) — confirm these listeners only ever fire from sync code.
    asyncio.run(update_fts_and_vec(mapper, connection, target))
# Replace the old event listener with the new sync version
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)

View File

@ -24,8 +24,7 @@ import typesense
from .config import get_database_path, settings
from memos.plugins.vlm import main as vlm_main
from memos.plugins.ocr import main as ocr_main
from . import crud
from . import indexing
from . import crud, indexing
from .schemas import (
Library,
Folder,
@ -767,7 +766,7 @@ async def search_entities_v2(
library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
try:
entities = crud.hybrid_search(
entities = await crud.hybrid_search(
query=q, db=db, limit=limit, library_ids=library_ids, start=start, end=end
)