mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(embedding): support remote embedding
This commit is contained in:
parent
bd5e7d6a6f
commit
8f3120bac8
@ -38,6 +38,7 @@ class EmbeddingSettings(BaseModel):
|
||||
endpoint: str = "http://localhost:11434/api/embed"
|
||||
model: str = "jinaai/jina-embeddings-v2-base-zh"
|
||||
use_modelscope: bool = False
|
||||
use_local: bool = True
|
||||
|
||||
|
||||
class TypesenseSettings(BaseModel):
|
||||
|
@ -27,10 +27,11 @@ from .models import (
|
||||
)
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from .embedding import generate_embeddings
|
||||
from .embedding import get_embeddings
|
||||
import logging
|
||||
from sqlite_vec import serialize_float32
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -458,17 +459,19 @@ def full_text_search(
|
||||
return ids
|
||||
|
||||
|
||||
def vec_search(
|
||||
async def vec_search(
|
||||
query: str,
|
||||
db: Session,
|
||||
limit: int = 200,
|
||||
library_ids: Optional[List[int]] = None,
|
||||
start: Optional[int] = None,
|
||||
end: Optional[int] = None,
|
||||
end: Optional[int] = None
|
||||
) -> List[int]:
|
||||
query_embedding = generate_embeddings([query])[0]
|
||||
query_embedding = await get_embeddings([query])
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
query_embedding = query_embedding[0]
|
||||
|
||||
sql_query = """
|
||||
SELECT entities.id FROM entities
|
||||
@ -514,7 +517,7 @@ def reciprocal_rank_fusion(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def hybrid_search(
|
||||
async def hybrid_search(
|
||||
query: str,
|
||||
db: Session,
|
||||
limit: int = 200,
|
||||
@ -530,7 +533,7 @@ def hybrid_search(
|
||||
logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")
|
||||
|
||||
vec_start = time.time()
|
||||
vec_results = vec_search(query, db, limit, library_ids, start, end)
|
||||
vec_results = await vec_search(query, db, limit, library_ids, start, end)
|
||||
vec_end = time.time()
|
||||
logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")
|
||||
|
||||
@ -559,4 +562,4 @@ def hybrid_search(
|
||||
total_time = end_time - start_time
|
||||
logger.info(f"Total hybrid search time: {total_time:.4f} seconds")
|
||||
|
||||
return result
|
||||
return result
|
@ -5,6 +5,8 @@ import numpy as np
|
||||
from modelscope import snapshot_download
|
||||
from .config import settings
|
||||
import logging
|
||||
import httpx
|
||||
import asyncio
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -14,6 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
model = None
|
||||
device = None
|
||||
|
||||
|
||||
def init_embedding_model():
|
||||
global model, device
|
||||
if torch.cuda.is_available():
|
||||
@ -34,21 +37,49 @@ def init_embedding_model():
|
||||
model.to(device)
|
||||
logger.info(f"Embedding model initialized on device: {device}")
|
||||
|
||||
|
||||
def generate_embeddings(texts: List[str]) -> List[List[float]]:
    """Encode texts with the local model and return L2-normalized vectors.

    Args:
        texts: Strings to embed. An empty list yields an empty result.

    Returns:
        One embedding per input text, as plain Python lists of floats. Each
        row is divided by its L2 norm; rows whose norm is zero are left
        unchanged.
    """
    global model

    # Answer the trivial case first so that an empty request never triggers
    # the lazy model initialization below.
    if not texts:
        return []

    if model is None:
        init_embedding_model()

    encoded = model.encode(texts, convert_to_tensor=True)
    vectors = encoded.cpu().numpy()

    # Normalize row by row; a zero norm is replaced by 1 to avoid dividing
    # by zero.
    norms = np.linalg.norm(vectors, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1

    return (vectors / norms).tolist()
|
||||
|
||||
|
||||
async def get_embeddings(texts: List[str]) -> List[List[float]]:
    """Return embeddings for texts with every value rounded to 5 decimals.

    The local model is used when ``settings.embedding.use_local`` is true;
    otherwise the embeddings are requested from the remote endpoint.
    """
    if not settings.embedding.use_local:
        vectors = await get_remote_embeddings(texts)
    else:
        vectors = generate_embeddings(texts)

    # Round the embedding values to 5 decimal places
    rounded = []
    for vector in vectors:
        rounded.append([round(float(value), 5) for value in vector])
    return rounded
|
||||
|
||||
|
||||
async def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
    """Fetch embeddings for texts from the configured remote endpoint.

    POSTs ``{"model": ..., "input": texts}`` as JSON to
    ``settings.embedding.endpoint`` and returns the ``"embeddings"`` entry of
    the JSON reply.

    Args:
        texts: Strings to embed.

    Returns:
        The embeddings from the response, or an empty list when the request
        fails, the server answers with an error status, or the reply does not
        have the expected shape. These failures are logged rather than raised.
    """
    payload = {"model": settings.embedding.model, "input": texts}

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(settings.embedding.endpoint, json=payload)
            response.raise_for_status()
            return response.json()["embeddings"]
        except httpx.HTTPError as e:
            # httpx.HTTPError is the common base of RequestError (transport
            # failures) and HTTPStatusError (raised by raise_for_status), so
            # 4xx/5xx replies are handled as best-effort failures as well.
            logger.error(f"Error fetching embeddings from remote endpoint: {e}")
        except (ValueError, KeyError, TypeError) as e:
            # The body was not valid JSON or had no "embeddings" entry.
            logger.error(
                f"Unexpected embedding response from remote endpoint: {e!r}"
            )

    return []  # Return an empty list instead of raising an exception
|
||||
|
@ -17,7 +17,8 @@ from .schemas import (
|
||||
RequestParams,
|
||||
)
|
||||
from .config import settings, TYPESENSE_COLLECTION_NAME
|
||||
from .embedding import generate_embeddings
|
||||
from .embedding import get_embeddings
|
||||
|
||||
|
||||
def convert_metadata_value(metadata: EntityMetadata):
|
||||
if metadata.data_type == MetadataType.JSON_DATA:
|
||||
@ -46,20 +47,6 @@ def parse_date_fields(entity):
|
||||
}
|
||||
|
||||
|
||||
async def get_embeddings(texts: List[str]) -> List[List[float]]:
    """Generate embeddings for texts, rounding every value to 5 decimals.

    Progress messages are printed to stdout. Any exception raised while
    generating or rounding the embeddings is re-raised as a plain
    ``Exception`` whose message starts with "Failed to generate embeddings".
    """
    print(f"Getting embeddings for {len(texts)} texts")

    try:
        raw_vectors = generate_embeddings(texts)
        print("Successfully generated embeddings.")
        rounded = []
        for vector in raw_vectors:
            rounded.append([round(float(value), 5) for value in vector])
        return rounded
    except Exception as e:
        raise Exception(f"Failed to generate embeddings: {str(e)}")
|
||||
|
||||
|
||||
def generate_metadata_text(metadata_entries):
|
||||
# 暂时不使用ocr结果
|
||||
def process_ocr_result(metadata):
|
||||
@ -102,7 +89,7 @@ async def bulk_upsert(client, entities):
|
||||
if metadata_text:
|
||||
metadata_texts.append(metadata_text)
|
||||
entities_with_metadata.append(entity)
|
||||
|
||||
|
||||
documents.append(
|
||||
EntityIndexItem(
|
||||
id=str(entity.id),
|
||||
@ -151,10 +138,10 @@ async def bulk_upsert(client, entities):
|
||||
)
|
||||
|
||||
|
||||
def upsert(client, entity):
|
||||
async def upsert(client, entity):
|
||||
date_fields = parse_date_fields(entity)
|
||||
metadata_text = generate_metadata_text(entity.metadata_entries)
|
||||
embedding = get_embeddings([metadata_text])[0]
|
||||
embedding = (await get_embeddings([metadata_text]))[0]
|
||||
|
||||
entity_data = EntityIndexItem(
|
||||
id=str(entity.id),
|
||||
|
@ -25,8 +25,9 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
from .embedding import generate_embeddings
|
||||
from .embedding import get_embeddings
|
||||
from sqlite_vec import serialize_float32
|
||||
import asyncio
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
@ -307,7 +308,7 @@ def update_entity_last_scan_at_for_metadata(mapper, connection, target):
|
||||
session.commit()
|
||||
|
||||
|
||||
def update_fts_and_vec(mapper, connection, target):
|
||||
async def update_fts_and_vec(mapper, connection, target):
|
||||
session = Session(bind=connection)
|
||||
|
||||
# Prepare FTS data
|
||||
@ -379,10 +380,13 @@ def update_fts_and_vec(mapper, connection, target):
|
||||
metadata_text = "\n".join(
|
||||
[
|
||||
f"{entry.key}: {entry.value}"
|
||||
for entry in target.metadata_entries if entry.key != 'ocr_result'
|
||||
for entry in target.metadata_entries
|
||||
if entry.key != "ocr_result"
|
||||
]
|
||||
)
|
||||
embeddings = generate_embeddings([metadata_text])
|
||||
|
||||
# Use the new get_embeddings function
|
||||
embeddings = await get_embeddings([metadata_text])
|
||||
if not embeddings:
|
||||
embedding = []
|
||||
else:
|
||||
@ -428,6 +432,12 @@ def delete_fts_and_vec(mapper, connection, target):
|
||||
)
|
||||
|
||||
|
||||
event.listen(EntityModel, "after_insert", update_fts_and_vec)
|
||||
event.listen(EntityModel, "after_update", update_fts_and_vec)
|
||||
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
|
||||
# Update the event listener to use asyncio
|
||||
def update_fts_and_vec_sync(mapper, connection, target):
    """Synchronous wrapper that runs the async ``update_fts_and_vec`` to completion.

    Takes the same ``(mapper, connection, target)`` arguments and forwards
    them unchanged.
    """
    # NOTE(review): asyncio.run() raises RuntimeError when the calling thread
    # already has a running event loop; confirm this is never invoked from
    # inside async code.
    pending_update = update_fts_and_vec(mapper, connection, target)
    asyncio.run(pending_update)
|
||||
|
||||
|
||||
# Replace the old event listener with the new sync version
|
||||
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
|
||||
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
|
||||
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
|
||||
|
@ -24,8 +24,7 @@ import typesense
|
||||
from .config import get_database_path, settings
|
||||
from memos.plugins.vlm import main as vlm_main
|
||||
from memos.plugins.ocr import main as ocr_main
|
||||
from . import crud
|
||||
from . import indexing
|
||||
from . import crud, indexing
|
||||
from .schemas import (
|
||||
Library,
|
||||
Folder,
|
||||
@ -767,7 +766,7 @@ async def search_entities_v2(
|
||||
library_ids = [int(id) for id in library_ids.split(",")] if library_ids else None
|
||||
|
||||
try:
|
||||
entities = crud.hybrid_search(
|
||||
entities = await crud.hybrid_search(
|
||||
query=q, db=db, limit=limit, library_ids=library_ids, start=start, end=end
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user