From 8f3120bac82825a0974f9dd06961c2544d9ed32d Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 8 Oct 2024 23:18:02 +0800 Subject: [PATCH] feat(embedding): support remote embedding --- memos/config.py | 1 + memos/crud.py | 17 ++++++++++------- memos/embedding.py | 41 ++++++++++++++++++++++++++++++++++++----- memos/indexing.py | 23 +++++------------------ memos/models.py | 24 +++++++++++++++++------- memos/server.py | 5 ++--- 6 files changed, 71 insertions(+), 40 deletions(-) diff --git a/memos/config.py b/memos/config.py index 8b6700c..fbc7242 100644 --- a/memos/config.py +++ b/memos/config.py @@ -38,6 +38,7 @@ class EmbeddingSettings(BaseModel): endpoint: str = "http://localhost:11434/api/embed" model: str = "jinaai/jina-embeddings-v2-base-zh" use_modelscope: bool = False + use_local: bool = True class TypesenseSettings(BaseModel): diff --git a/memos/crud.py b/memos/crud.py index da87a21..15b5036 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -27,10 +27,11 @@ from .models import ( ) import numpy as np from collections import defaultdict -from .embedding import generate_embeddings +from .embedding import get_embeddings import logging from sqlite_vec import serialize_float32 import time +import asyncio logger = logging.getLogger(__name__) @@ -458,17 +459,19 @@ def full_text_search( return ids -def vec_search( +async def vec_search( query: str, db: Session, limit: int = 200, library_ids: Optional[List[int]] = None, start: Optional[int] = None, - end: Optional[int] = None, + end: Optional[int] = None ) -> List[int]: - query_embedding = generate_embeddings([query])[0] + query_embedding = await get_embeddings([query]) if not query_embedding: return [] + + query_embedding = query_embedding[0] sql_query = """ SELECT entities.id FROM entities @@ -514,7 +517,7 @@ def reciprocal_rank_fusion( return sorted_results -def hybrid_search( +async def hybrid_search( query: str, db: Session, limit: int = 200, @@ -530,7 +533,7 @@ def hybrid_search( logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds") vec_start = time.time() - vec_results = vec_search(query, db, limit, library_ids, start, end) + vec_results = await vec_search(query, db, limit, library_ids, start, end) vec_end = time.time() logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds") @@ -559,4 +562,4 @@ def hybrid_search( total_time = end_time - start_time logger.info(f"Total hybrid search time: {total_time:.4f} seconds") - return result + return result \ No newline at end of file diff --git a/memos/embedding.py b/memos/embedding.py index adb407d..ab7c50d 100644 --- a/memos/embedding.py +++ b/memos/embedding.py @@ -5,6 +5,8 @@ import numpy as np from modelscope import snapshot_download from .config import settings import logging +import httpx +import asyncio # Configure logger logging.basicConfig(level=logging.INFO) @@ -14,6 +16,7 @@ logger = logging.getLogger(__name__) model = None device = None + def init_embedding_model(): global model, device if torch.cuda.is_available(): @@ -34,21 +37,49 @@ def init_embedding_model(): model.to(device) logger.info(f"Embedding model initialized on device: {device}") + def generate_embeddings(texts: List[str]) -> List[List[float]]: global model - + if model is None: init_embedding_model() - + if not texts: return [] embeddings = model.encode(texts, convert_to_tensor=True) embeddings = embeddings.cpu().numpy() - + # Normalize embeddings norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) norms[norms == 0] = 1 embeddings = embeddings / norms - - return embeddings.tolist() \ No newline at end of file + + return embeddings.tolist() + + +async def get_embeddings(texts: List[str]) -> List[List[float]]: + if settings.embedding.use_local: + embeddings = generate_embeddings(texts) + else: + embeddings = await get_remote_embeddings(texts) + + # Round the embedding values to 5 decimal places + return [ + [round(float(x), 5) for x in embedding] + for embedding in embeddings + ] + + +async def get_remote_embeddings(texts: List[str]) -> List[List[float]]: + payload = {"model": settings.embedding.model, "input": texts} + + async with httpx.AsyncClient() as client: + try: + response = await client.post(settings.embedding.endpoint, json=payload) + response.raise_for_status() + result = response.json() + return result["embeddings"] + except httpx.RequestError as e: + logger.error(f"Error fetching embeddings from remote endpoint: {e}") + return [] # Return an empty list instead of raising an exception diff --git a/memos/indexing.py b/memos/indexing.py index f4446a4..9da83fb 100644 --- a/memos/indexing.py +++ b/memos/indexing.py @@ -17,7 +17,8 @@ from .schemas import ( RequestParams, ) from .config import settings, TYPESENSE_COLLECTION_NAME -from .embedding import generate_embeddings +from .embedding import get_embeddings + def convert_metadata_value(metadata: EntityMetadata): if metadata.data_type == MetadataType.JSON_DATA: @@ -46,20 +47,6 @@ def parse_date_fields(entity): } -async def get_embeddings(texts: List[str]) -> List[List[float]]: - print(f"Getting embeddings for {len(texts)} texts") - - try: - embeddings = generate_embeddings(texts) - print("Successfully generated embeddings.") - return [ - [round(float(x), 5) for x in embedding] - for embedding in embeddings - ] - except Exception as e: - raise Exception(f"Failed to generate embeddings: {str(e)}") - - def generate_metadata_text(metadata_entries): # 暂时不使用ocr结果 def process_ocr_result(metadata): @@ -102,7 +89,7 @@ async def bulk_upsert(client, entities): if metadata_text: metadata_texts.append(metadata_text) entities_with_metadata.append(entity) - + documents.append( EntityIndexItem( id=str(entity.id), @@ -151,10 +138,10 @@ async def bulk_upsert(client, entities): ) -def upsert(client, entity): +async def upsert(client, entity): date_fields = parse_date_fields(entity) metadata_text = generate_metadata_text(entity.metadata_entries) - embedding = get_embeddings([metadata_text])[0] + embedding = (await get_embeddings([metadata_text]))[0] entity_data = EntityIndexItem( id=str(entity.id), diff --git a/memos/models.py b/memos/models.py index b0bf4c1..e46e4e5 100644 --- a/memos/models.py +++ b/memos/models.py @@ -25,8 +25,9 @@ import os import sys from pathlib import Path import json -from .embedding import generate_embeddings +from .embedding import get_embeddings from sqlite_vec import serialize_float32 +import asyncio class Base(DeclarativeBase): @@ -307,7 +308,7 @@ def update_entity_last_scan_at_for_metadata(mapper, connection, target): session.commit() -def update_fts_and_vec(mapper, connection, target): +async def update_fts_and_vec(mapper, connection, target): session = Session(bind=connection) # Prepare FTS data @@ -379,10 +380,13 @@ def update_fts_and_vec(mapper, connection, target): metadata_text = "\n".join( [ f"{entry.key}: {entry.value}" - for entry in target.metadata_entries if entry.key != 'ocr_result' + for entry in target.metadata_entries + if entry.key != "ocr_result" ] ) - embeddings = generate_embeddings([metadata_text]) + + # Use the new get_embeddings function + embeddings = await get_embeddings([metadata_text]) if not embeddings: embedding = [] else: @@ -428,6 +432,12 @@ def delete_fts_and_vec(mapper, connection, target): ) -event.listen(EntityModel, "after_insert", update_fts_and_vec) -event.listen(EntityModel, "after_update", update_fts_and_vec) -event.listen(EntityModel, "after_delete", delete_fts_and_vec) \ No newline at end of file +# Update the event listener to use asyncio +def update_fts_and_vec_sync(mapper, connection, target): + asyncio.run(update_fts_and_vec(mapper, connection, target)) + + +# Replace the old event listener with the new sync version +event.listen(EntityModel, "after_insert", update_fts_and_vec_sync) +event.listen(EntityModel, "after_update", update_fts_and_vec_sync) +event.listen(EntityModel, "after_delete", delete_fts_and_vec) diff --git a/memos/server.py b/memos/server.py index 9c2ddcd..b69759a 100644 --- a/memos/server.py +++ b/memos/server.py @@ -24,8 +24,7 @@ import typesense from .config import get_database_path, settings from memos.plugins.vlm import main as vlm_main from memos.plugins.ocr import main as ocr_main -from . import crud -from . import indexing +from . import crud, indexing from .schemas import ( Library, Folder, @@ -767,7 +766,7 @@ async def search_entities_v2( library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None try: - entities = crud.hybrid_search( + entities = await crud.hybrid_search( query=q, db=db, limit=limit, library_ids=library_ids, start=start, end=end )