from typing import List, Tuple, Optional

from sqlalchemy.orm import Session
from sqlalchemy import func, text

from .schemas import (
    Library,
    NewLibraryParam,
    Folder,
    NewEntityParam,
    Entity,
    Plugin,
    NewPluginParam,
    UpdateEntityParam,
    NewFoldersParam,
    MetadataSource,
    EntityMetadataParam,
)
from .models import (
    LibraryModel,
    FolderModel,
    EntityModel,
    PluginModel,
    LibraryPluginModel,
    TagModel,
    EntityMetadataModel,
    EntityTagModel,
)
from collections import defaultdict
from .embedding import get_embeddings
import logging
from sqlite_vec import serialize_float32
import time
import json

logger = logging.getLogger(__name__)


def get_library_by_id(library_id: int, db: Session) -> Library | None:
    return db.query(LibraryModel).filter(LibraryModel.id == library_id).first()


def create_library(library: NewLibraryParam, db: Session) -> Library:
    db_library = LibraryModel(name=library.name)
    db.add(db_library)
    db.commit()
    db.refresh(db_library)

    for folder in library.folders:
        db_folder = FolderModel(
            path=str(folder.path),
            library_id=db_library.id,
            last_modified_at=folder.last_modified_at,
            type=folder.type,
        )
        db.add(db_folder)

    db.commit()
    return Library(
        id=db_library.id,
        name=db_library.name,
        folders=[
            Folder(
                id=db_folder.id,
                path=db_folder.path,
                last_modified_at=db_folder.last_modified_at,
                type=db_folder.type,
            )
            for db_folder in db_library.folders
        ],
        plugins=[],
    )


def get_libraries(db: Session) -> List[Library]:
    return db.query(LibraryModel).all()


def get_library_by_name(library_name: str, db: Session) -> Library | None:
    return (
        db.query(LibraryModel)
        .filter(func.lower(LibraryModel.name) == library_name.lower())
        .first()
    )


def add_folders(library_id: int, folders: NewFoldersParam, db: Session) -> Library:
    for folder in folders.folders:
        db_folder = FolderModel(
            path=str(folder.path),
            library_id=library_id,
            last_modified_at=folder.last_modified_at,
            type=folder.type,
        )
        db.add(db_folder)
        db.commit()
        db.refresh(db_folder)

    db_library = db.query(LibraryModel).filter(LibraryModel.id == library_id).first()
    return Library(**db_library.__dict__)


def create_entity(
    library_id: int,
    entity: NewEntityParam,
    db: Session,
) -> Entity:
    tags = entity.tags
    metadata_entries = entity.metadata_entries

    # Remove tags and metadata_entries from entity
    entity.tags = None
    entity.metadata_entries = None

    db_entity = EntityModel(
        **entity.model_dump(exclude_none=True), library_id=library_id
    )
    db.add(db_entity)
    db.commit()
    db.refresh(db_entity)

    # Handle tags separately
    if tags:
        for tag_name in tags:
            tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
            if not tag:
                tag = TagModel(name=tag_name)
                db.add(tag)
                db.commit()
                db.refresh(tag)
            entity_tag = EntityTagModel(
                entity_id=db_entity.id,
                tag_id=tag.id,
                source=MetadataSource.PLUGIN_GENERATED,
            )
            db.add(entity_tag)
        db.commit()

    # Handle attrs separately
    if metadata_entries:
        for attr in metadata_entries:
            entity_metadata = EntityMetadataModel(
                entity_id=db_entity.id,
                key=attr.key,
                value=attr.value,
                source=attr.source,
                source_type=MetadataSource.PLUGIN_GENERATED if attr.source else None,
                data_type=attr.data_type,
            )
            db.add(entity_metadata)
        db.commit()
        db.refresh(db_entity)

    return Entity(**db_entity.__dict__)


def get_entity_by_id(entity_id: int, db: Session) -> Entity | None:
    return db.query(EntityModel).filter(EntityModel.id == entity_id).first()


def get_entities_of_folder(
    library_id: int, folder_id: int, db: Session, limit: int = 10, offset: int = 0
) -> Tuple[List[Entity], int]:
    folder = (
        db.query(FolderModel)
        .filter(FolderModel.id == folder_id, FolderModel.library_id == library_id)
        .first()
    )
    if folder is None:
        return [], 0

    query = db.query(EntityModel).filter(EntityModel.folder_id == folder_id)

    total_count = query.count()

    entities = query.limit(limit).offset(offset).all()

    return entities, total_count


def get_entity_by_filepath(filepath: str, db: Session) -> Entity | None:
    return db.query(EntityModel).filter(EntityModel.filepath == filepath).first()


def get_entities_by_filepaths(filepaths: List[str], db: Session) -> List[Entity]:
    return db.query(EntityModel).filter(EntityModel.filepath.in_(filepaths)).all()


def remove_entity(entity_id: int, db: Session):
    entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
    if entity:
        # Delete the entity from FTS and vec tables first
        db.execute(text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity_id})
        db.execute(
            text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity_id}
        )

        # Then delete the entity itself
        db.delete(entity)
        db.commit()
    else:
        raise ValueError(f"Entity with id {entity_id} not found")


def create_plugin(newPlugin: NewPluginParam, db: Session) -> Plugin:
    db_plugin = PluginModel(**newPlugin.model_dump(mode="json"))
    db.add(db_plugin)
    db.commit()
    db.refresh(db_plugin)
    return db_plugin


def get_plugins(db: Session) -> List[Plugin]:
    return db.query(PluginModel).all()


def get_plugin_by_name(plugin_name: str, db: Session) -> Plugin | None:
    return (
        db.query(PluginModel)
        .filter(func.lower(PluginModel.name) == plugin_name.lower())
        .first()
    )


def add_plugin_to_library(library_id: int, plugin_id: int, db: Session):
    library_plugin = LibraryPluginModel(library_id=library_id, plugin_id=plugin_id)
    db.add(library_plugin)
    db.commit()
    db.refresh(library_plugin)


def find_entity_by_id(entity_id: int, db: Session) -> Entity | None:
    db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
    if db_entity is None:
        return None
    return Entity(**db_entity.__dict__)


def find_entities_by_ids(entity_ids: List[int], db: Session) -> List[Entity]:
    db_entities = db.query(EntityModel).filter(EntityModel.id.in_(entity_ids)).all()
    return [Entity(**entity.__dict__) for entity in db_entities]


def update_entity(
    entity_id: int,
    updated_entity: UpdateEntityParam,
    db: Session,
) -> Entity:
    db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()

    if db_entity is None:
        raise ValueError(f"Entity with id {entity_id} not found")

    # Update the main fields of the entity
    for key, value in updated_entity.model_dump().items():
        if key not in ["tags", "metadata_entries"] and value is not None:
            setattr(db_entity, key, value)

    # Handle tags separately
    if updated_entity.tags is not None:
        # Clear existing tags
        db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete()
        db.commit()

        for tag_name in updated_entity.tags:
            tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
            if not tag:
                tag = TagModel(name=tag_name)
                db.add(tag)
                db.commit()
                db.refresh(tag)
            entity_tag = EntityTagModel(
                entity_id=db_entity.id,
                tag_id=tag.id,
                source=MetadataSource.PLUGIN_GENERATED,
            )
            db.add(entity_tag)
        db.commit()

    # Handle attrs separately
    if updated_entity.metadata_entries is not None:
        # Clear existing attrs
        db.query(EntityMetadataModel).filter(
            EntityMetadataModel.entity_id == entity_id
        ).delete()
        db.commit()

        for attr in updated_entity.metadata_entries:
            entity_metadata = EntityMetadataModel(
                entity_id=db_entity.id,
                key=attr.key,
                value=attr.value,
                source=attr.source,
                source_type=(
                    MetadataSource.PLUGIN_GENERATED if attr.source is not None else None
                ),
                data_type=attr.data_type,
            )
            db.add(entity_metadata)
            db_entity.metadata_entries.append(entity_metadata)

    db.commit()
    db.refresh(db_entity)

    return Entity(**db_entity.__dict__)


def touch_entity(entity_id: int, db: Session) -> bool:
    db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
    if db_entity:
        db_entity.last_scan_at = func.now()
        db.commit()
        db.refresh(db_entity)
        return True
    else:
        return False


def update_entity_tags(
    entity_id: int,
    tags: List[str],
    db: Session,
) -> Entity:
    db_entity = get_entity_by_id(entity_id, db)
    if not db_entity:
        raise ValueError(f"Entity with id {entity_id} not found")

    # Clear existing tags
    db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete()

    for tag_name in tags:
        tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
        if not tag:
            tag = TagModel(name=tag_name)
            db.add(tag)
            db.commit()
            db.refresh(tag)
        entity_tag = EntityTagModel(
            entity_id=db_entity.id,
            tag_id=tag.id,
            source=MetadataSource.PLUGIN_GENERATED,
        )
        db.add(entity_tag)

    # Update last_scan_at in the same transaction
    db_entity.last_scan_at = func.now()

    db.commit()
    db.refresh(db_entity)

    return Entity(**db_entity.__dict__)


def add_new_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
    db_entity = get_entity_by_id(entity_id, db)
    if not db_entity:
        raise ValueError(f"Entity with id {entity_id} not found")

    existing_tags = set(tag.name for tag in db_entity.tags)
    new_tags = set(tags) - existing_tags

    for tag_name in new_tags:
        tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
        if not tag:
            tag = TagModel(name=tag_name)
            db.add(tag)
            db.commit()
            db.refresh(tag)
        entity_tag = EntityTagModel(
            entity_id=db_entity.id,
            tag_id=tag.id,
            source=MetadataSource.PLUGIN_GENERATED,
        )
        db.add(entity_tag)

    # Update last_scan_at in the same transaction
    db_entity.last_scan_at = func.now()

    db.commit()
    db.refresh(db_entity)

    return Entity(**db_entity.__dict__)


def update_entity_metadata_entries(
    entity_id: int,
    updated_metadata: List[EntityMetadataParam],
    db: Session,
) -> Entity:
    db_entity = get_entity_by_id(entity_id, db)
    if not db_entity:
        raise ValueError(f"Entity with id {entity_id} not found")

    existing_metadata_entries = (
        db.query(EntityMetadataModel)
        .filter(EntityMetadataModel.entity_id == db_entity.id)
        .all()
    )

    existing_metadata_dict = {entry.key: entry for entry in existing_metadata_entries}

    for metadata in updated_metadata:
        if metadata.key in existing_metadata_dict:
            existing_metadata = existing_metadata_dict[metadata.key]
            existing_metadata.value = metadata.value
            existing_metadata.source = (
                metadata.source
                if metadata.source is not None
                else existing_metadata.source
            )
            existing_metadata.source_type = (
                MetadataSource.PLUGIN_GENERATED
                if metadata.source is not None
                else existing_metadata.source_type
            )
            existing_metadata.data_type = metadata.data_type
        else:
            entity_metadata = EntityMetadataModel(
                entity_id=db_entity.id,
                key=metadata.key,
                value=metadata.value,
                source=metadata.source,
                source_type=(
                    MetadataSource.PLUGIN_GENERATED
                    if metadata.source is not None
                    else None
                ),
                data_type=metadata.data_type,
            )
            db.add(entity_metadata)
            db_entity.metadata_entries.append(entity_metadata)

    # Update last_scan_at in the same transaction
    db_entity.last_scan_at = func.now()

    db.commit()
    db.refresh(db_entity)

    return Entity(**db_entity.__dict__)


def get_plugin_by_id(plugin_id: int, db: Session) -> Plugin | None:
    return db.query(PluginModel).filter(PluginModel.id == plugin_id).first()


def remove_plugin_from_library(library_id: int, plugin_id: int, db: Session):
    library_plugin = (
        db.query(LibraryPluginModel)
        .filter(
            LibraryPluginModel.library_id == library_id,
            LibraryPluginModel.plugin_id == plugin_id,
        )
        .first()
    )

    if library_plugin:
        db.delete(library_plugin)
        db.commit()
    else:
        raise ValueError(f"Plugin {plugin_id} not found in library {library_id}")


def and_words(input_string):
    words = input_string.split()
    result = " AND ".join(words)
    return result


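# Illustrative example (query string is made up): and_words rewrites a
# free-text query into an FTS5 conjunction so that every term must match,
# e.g. and_words("red panda") -> "red AND panda".

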
def full_text_search(
    query: str,
    db: Session,
    limit: int = 200,
    library_ids: Optional[List[int]] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> List[int]:
    and_query = and_words(query)

    sql_query = """
    SELECT entities.id FROM entities
    JOIN entities_fts ON entities.id = entities_fts.id
    WHERE entities_fts MATCH jieba_query(:query)
    AND entities.file_type_group = 'image'
    """

    params = {"query": and_query, "limit": limit}

    if library_ids:
        library_ids_str = ", ".join(f"'{id}'" for id in library_ids)
        sql_query += f" AND entities.library_id IN ({library_ids_str})"

    if start is not None and end is not None:
        sql_query += " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
        params["start"] = str(start)
        params["end"] = str(end)

    sql_query += (
        " ORDER BY bm25(entities_fts), entities.file_created_at DESC LIMIT :limit"
    )

    result = db.execute(text(sql_query), params).fetchall()

    logger.info(f"Full-text search sql: {sql_query}")
    logger.info(f"Full-text search params: {params}")
    ids = [row[0] for row in result]
    logger.info(f"Full-text search results: {ids}")
    return ids


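# Minimal usage sketch (assumes `db` is a SQLAlchemy Session bound to the
# SQLite database that defines the entities_fts FTS5 table and registers the
# jieba_query SQL function; the query string and library id are made up):
#
#     ids = full_text_search("meeting notes", db, limit=50, library_ids=[1])
#
# start/end, when given, are interpreted as epoch seconds and restrict
# matches by entities.file_created_at.

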
def vec_search(
    query: str,
    db: Session,
    limit: int = 200,
    library_ids: Optional[List[int]] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> List[int]:
    query_embedding = get_embeddings([query])
    if not query_embedding:
        return []

    query_embedding = query_embedding[0]

    sql_query = """
    SELECT entities.id FROM entities
    JOIN entities_vec ON entities.id = entities_vec.rowid
    WHERE entities_vec.embedding MATCH :embedding
    AND entities.file_type_group = 'image'
    """

    params = {"embedding": serialize_float32(query_embedding), "limit": limit}

    if library_ids:
        library_ids_str = ", ".join(f"'{id}'" for id in library_ids)
        sql_query += f" AND entities.library_id IN ({library_ids_str})"

    if start is not None and end is not None:
        sql_query += " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
        params["start"] = str(start)
        params["end"] = str(end)

    sql_query += " AND K = :limit ORDER BY distance, entities.file_created_at DESC"

    result = db.execute(text(sql_query), params).fetchall()

    ids = [row[0] for row in result]
    logger.info(f"Vector search results: {ids}")
    return ids


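# Note on the query shape above: sqlite-vec virtual tables perform a KNN
# scan when the WHERE clause combines `embedding MATCH :embedding` with a
# `k = :limit` constraint (SQLite identifiers are case-insensitive, hence
# the uppercase K works), and they expose a `distance` column to sort by.
# The query vector is passed as a compact float32 blob via
# serialize_float32 rather than as JSON.

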
def reciprocal_rank_fusion(
    fts_results: List[int], vec_results: List[int], k: int = 60
) -> List[Tuple[int, float]]:
    rank_dict = defaultdict(float)

    for rank, result_id in enumerate(fts_results):
        rank_dict[result_id] += 1 / (k + rank + 1)

    for rank, result_id in enumerate(vec_results):
        rank_dict[result_id] += 1 / (k + rank + 1)

    sorted_results = sorted(rank_dict.items(), key=lambda x: x[1], reverse=True)
    return sorted_results


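# Worked example (illustrative IDs, default k=60): an entity ranked 1st in
# the FTS list and 3rd in the vector list scores
# 1/(60+1) + 1/(60+3) ≈ 0.0323, while an entity appearing in only one list,
# even at rank 1, scores at most 1/61 ≈ 0.0164, so IDs near the top of both
# lists float to the top of the fused ranking.

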
def hybrid_search(
    query: str,
    db: Session,
    limit: int = 200,
    library_ids: Optional[List[int]] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> List[Entity]:
    start_time = time.time()

    fts_start = time.time()
    fts_results = full_text_search(query, db, limit, library_ids, start, end)
    fts_end = time.time()
    logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")

    vec_start = time.time()
    vec_results = vec_search(query, db, limit, library_ids, start, end)
    vec_end = time.time()
    logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")

    fusion_start = time.time()
    combined_results = reciprocal_rank_fusion(fts_results, vec_results)
    fusion_end = time.time()
    logger.info(f"Reciprocal rank fusion took {fusion_end - fusion_start:.4f} seconds")

    sorted_ids = [id for id, _ in combined_results][:limit]
    logger.info(f"Hybrid search results (sorted IDs): {sorted_ids}")

    entities_start = time.time()
    entities = find_entities_by_ids(sorted_ids, db)
    entities_end = time.time()
    logger.info(
        f"Finding entities by IDs took {entities_end - entities_start:.4f} seconds"
    )

    # Create a dictionary mapping entity IDs to entities
    entity_dict = {entity.id: entity for entity in entities}

    # Return entities in the order of sorted_ids
    result = [entity_dict[id] for id in sorted_ids if id in entity_dict]

    end_time = time.time()
    total_time = end_time - start_time
    logger.info(f"Total hybrid search time: {total_time:.4f} seconds")

    return result


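# End-to-end sketch (assumes a session factory from this project, here a
# hypothetical SessionLocal, and at least one indexed image library):
#
#     with SessionLocal() as db:
#         hits = hybrid_search("screenshot of terminal", db, limit=20)
#         for entity in hits:
#             print(entity.id, entity.filepath)
#
# Both branches share the same limit and time window, and the fusion step
# only reorders candidates that at least one branch returned.

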
def list_entities(
    db: Session,
    limit: int = 200,
    library_ids: Optional[List[int]] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> List[Entity]:
    query = db.query(EntityModel).filter(EntityModel.file_type_group == "image")

    if library_ids:
        query = query.filter(EntityModel.library_id.in_(library_ids))

    if start is not None and end is not None:
        query = query.filter(
            func.strftime("%s", EntityModel.file_created_at, "utc").between(
                str(start), str(end)
            )
        )

    entities = query.order_by(EntityModel.file_created_at.desc()).limit(limit).all()

    return [Entity(**entity.__dict__) for entity in entities]


def get_entity_context(
    db: Session, library_id: int, entity_id: int, prev: int = 0, next: int = 0
) -> Tuple[List[Entity], List[Entity]]:
    """
    Get the context (previous and next entities) for a given entity.
    Returns a tuple of (previous_entities, next_entities).
    """
    # First get the target entity to get its timestamp
    target_entity = (
        db.query(EntityModel)
        .filter(
            EntityModel.id == entity_id,
            EntityModel.library_id == library_id,
        )
        .first()
    )

    if not target_entity:
        return [], []

    # Get previous entities
    prev_entities = []
    if prev > 0:
        prev_entities = (
            db.query(EntityModel)
            .filter(
                EntityModel.library_id == library_id,
                EntityModel.file_created_at < target_entity.file_created_at,
            )
            .order_by(EntityModel.file_created_at.desc())
            .limit(prev)
            .all()
        )
        # Reverse the list to get chronological order and convert to Entity models
        prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1]

    # Get next entities
    next_entities = []
    if next > 0:
        next_entities = (
            db.query(EntityModel)
            .filter(
                EntityModel.library_id == library_id,
                EntityModel.file_created_at > target_entity.file_created_at,
            )
            .order_by(EntityModel.file_created_at.asc())
            .limit(next)
            .all()
        )
        # Convert to Entity models
        next_entities = [Entity(**entity.__dict__) for entity in next_entities]

    return prev_entities, next_entities


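# Illustrative call (IDs are hypothetical): fetch the two entities captured
# just before and just after entity 42 in library 1, both lists returned in
# chronological order:
#
#     before, after = get_entity_context(
#         db, library_id=1, entity_id=42, prev=2, next=2
#     )

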
def process_ocr_result(value, max_length=4096):
    # Note: max_length caps the number of OCR boxes whose text is joined,
    # not the number of characters in the result.
    try:
        ocr_data = json.loads(value)
        if isinstance(ocr_data, list) and all(
            isinstance(item, dict)
            and "dt_boxes" in item
            and "rec_txt" in item
            and "score" in item
            for item in ocr_data
        ):
            return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
        else:
            return json.dumps(ocr_data, indent=2)
    except json.JSONDecodeError:
        return value


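# Example of the recognized shape (a hedged sketch of typical RapidOCR-style
# output; the field values are made up):
#
#     value = (
#         '[{"dt_boxes": [[0, 0], [10, 0], [10, 10], [0, 10]],'
#         ' "rec_txt": "hello", "score": 0.98}]'
#     )
#     process_ocr_result(value)  # -> "hello"
#
# JSON that parses but does not match this shape is re-serialized as
# pretty-printed JSON, and non-JSON input is returned unchanged.

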
def prepare_fts_data(entity: EntityModel) -> tuple[str, str]:
    tags = ", ".join([tag.name for tag in entity.tags])
    fts_metadata = "\n".join(
        [
            f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
            for entry in entity.metadata_entries
        ]
    )
    return tags, fts_metadata


def prepare_vec_data(entity: EntityModel) -> str:
    vec_metadata = "\n".join(
        [
            f"{entry.key}: {entry.value}"
            for entry in entity.metadata_entries
            if entry.key != "ocr_result"
        ]
    )
    ocr_result = next(
        (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
        "",
    )
    vec_metadata += f"\nocr_result: {process_ocr_result(ocr_result, max_length=128)}"
    return vec_metadata


def update_entity_index(entity: EntityModel, db: Session):
    """Update both FTS and vector indexes for an entity"""
    try:
        # Update FTS index
        tags, fts_metadata = prepare_fts_data(entity)
        db.execute(
            text(
                """
                INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
                VALUES(:id, :filepath, :tags, :metadata)
                """
            ),
            {
                "id": entity.id,
                "filepath": entity.filepath,
                "tags": tags,
                "metadata": fts_metadata,
            },
        )

        # Update vector index
        vec_metadata = prepare_vec_data(entity)
        embeddings = get_embeddings([vec_metadata])

        if embeddings and embeddings[0]:
            db.execute(
                text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
            )
            db.execute(
                text(
                    """
                    INSERT INTO entities_vec (rowid, embedding)
                    VALUES (:id, :embedding)
                    """
                ),
                {
                    "id": entity.id,
                    "embedding": serialize_float32(embeddings[0]),
                },
            )

        db.commit()
    except Exception as e:
        logger.error(f"Error updating indexes for entity {entity.id}: {e}")
        db.rollback()
        raise


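# Typical call site (a sketch; the entity id is hypothetical):
#
#     entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
#     if entity:
#         update_entity_index(entity, db)
#
# The vector row is deleted and re-inserted rather than upserted, presumably
# because the vec0 virtual table does not accept INSERT OR REPLACE the way
# the FTS5 table does.

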
def batch_update_entity_indices(entity_ids: List[int], db: Session):
    """Batch update both FTS and vector indexes for multiple entities"""
    try:
        # Fetch the entities
        entities = db.query(EntityModel).filter(EntityModel.id.in_(entity_ids)).all()
        found_ids = {entity.id for entity in entities}

        # Make sure every requested entity was found
        missing_ids = set(entity_ids) - found_ids
        if missing_ids:
            raise ValueError(f"Entities not found: {missing_ids}")

        # Prepare FTS data for all entities
        fts_data = []
        vec_metadata_list = []

        for entity in entities:
            # Prepare FTS data
            tags, fts_metadata = prepare_fts_data(entity)
            fts_data.append((entity.id, entity.filepath, tags, fts_metadata))

            # Prepare vector data
            vec_metadata = prepare_vec_data(entity)
            vec_metadata_list.append(vec_metadata)

        # Batch update FTS table
        for entity_id, filepath, tags, metadata in fts_data:
            db.execute(
                text(
                    """
                    INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
                    VALUES(:id, :filepath, :tags, :metadata)
                    """
                ),
                {
                    "id": entity_id,
                    "filepath": filepath,
                    "tags": tags,
                    "metadata": metadata,
                },
            )

        # Batch get embeddings
        embeddings = get_embeddings(vec_metadata_list)

        # Batch update vector table
        if embeddings:
            for entity, embedding in zip(entities, embeddings):
                if embedding:  # Check if embedding is not empty
                    db.execute(
                        text("DELETE FROM entities_vec WHERE rowid = :id"),
                        {"id": entity.id},
                    )
                    db.execute(
                        text(
                            """
                            INSERT INTO entities_vec (rowid, embedding)
                            VALUES (:id, :embedding)
                            """
                        ),
                        {
                            "id": entity.id,
                            "embedding": serialize_float32(embedding),
                        },
                    )

        db.commit()
    except Exception as e:
        logger.error(f"Error batch updating indexes: {e}")
        db.rollback()
        raise


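# Design note plus a hedged usage sketch: the batch variant exists so that
# get_embeddings is called once with all documents instead of once per
# entity, which is usually the dominant cost when reindexing. For example
# (the selection query is illustrative, not part of this module):
#
#     ids = [e.id for e in db.query(EntityModel).limit(100)]
#     batch_update_entity_indices(ids, db)
#
# All index writes share one transaction, so a failure rolls back the whole
# batch rather than leaving the FTS and vector tables out of sync.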