feat: solve concurrency problem

arkohut 2024-11-07 18:23:15 +08:00
parent 22c2158a58
commit 7adb08e6d9
8 changed files with 326 additions and 235 deletions

View File

@@ -507,7 +507,10 @@ async def add_entity(
         post_response = await client.post(
             f"{BASE_URL}/libraries/{library_id}/entities",
             json=new_entity,
-            params={"plugins": plugins} if plugins else {},
+            params={
+                "plugins": plugins,
+                "update_index": "true"
+            } if plugins else {"update_index": "true"},
             timeout=60,
         )
         if 200 <= post_response.status_code < 300:
@@ -546,6 +549,7 @@ async def update_entity(
             json=new_entity,
             params={
                 "trigger_webhooks_flag": "true",
+                "update_index": "true",
                 **({"plugins": plugins} if plugins else {}),
             },
             timeout=60,
@@ -581,6 +585,7 @@ def reindex(
     force: bool = typer.Option(
         False, "--force", help="Force recreate FTS and vector tables before reindexing"
     ),
+    batch_size: int = typer.Option(1, "--batch-size", "-bs", help="Batch size for processing entities"),
 ):
     print(f"Reindexing library {library_id}")
@@ -608,7 +613,7 @@ def reindex(
         recreate_fts_and_vec_tables()
         print("FTS and vector tables have been recreated.")

-    with httpx.Session() as client:
+    with httpx.Client() as client:
         total_entities = 0

         # Get total entity count for all folders
@@ -647,56 +652,34 @@ def reindex(
                 if not entities:
                     break

-                # Update last_scan_at for each entity
-                for entity in entities:
-                    if entity["id"] in scanned_entities:
-                        continue
-
-                    update_response = client.post(
-                        f"{BASE_URL}/entities/{entity['id']}/last-scan-at"
-                    )
-                    if update_response.status_code != 204:
-                        print(
-                            f"Failed to update last_scan_at for entity {entity['id']}: {update_response.status_code} - {update_response.text}"
-                        )
-                    else:
-                        scanned_entities.add(entity["id"])
-                    pbar.update(1)
+                # Collect the entity IDs that still need processing
+                entity_ids = [
+                    entity["id"]
+                    for entity in entities
+                    if entity["id"] not in scanned_entities
+                ]
+
+                # Process in batches of batch_size
+                for i in range(0, len(entity_ids), batch_size):
+                    batch_ids = entity_ids[i:i + batch_size]
+                    if batch_ids:
+                        batch_response = client.post(
+                            f"{BASE_URL}/entities/batch-index",
+                            json={"entity_ids": batch_ids},
+                            timeout=60,
+                        )
+                        if batch_response.status_code != 204:
+                            print(
+                                f"Failed to update batch: {batch_response.status_code} - {batch_response.text}"
+                            )
+                        pbar.update(len(batch_ids))
+                        scanned_entities.update(batch_ids)

                 offset += limit

     print(f"Reindexing completed for library {library_id}")


-async def check_and_index_entity(client, entity_id, entity_last_scan_at):
-    try:
-        index_response = client.get(f"{BASE_URL}/entities/{entity_id}/index")
-        if index_response.status_code == 200:
-            index_data = index_response.json()
-            if index_data["last_scan_at"] is None:
-                return entity_last_scan_at is not None
-            index_last_scan_at = datetime.fromtimestamp(index_data["last_scan_at"])
-            entity_last_scan_at = datetime.fromisoformat(entity_last_scan_at)
-            if index_last_scan_at >= entity_last_scan_at:
-                return False  # Index is up to date, no need to update
-        return True  # Index doesn't exist or needs update
-    except httpx.HTTPStatusError as e:
-        if e.response.status_code == 404:
-            return True  # Index doesn't exist, need to create
-        raise  # Re-raise other HTTP errors
-
-
-async def index_batch(client, entity_ids):
-    index_response = client.post(
-        f"{BASE_URL}/entities/batch-index",
-        json=entity_ids,
-        timeout=60,
-    )
-    return index_response
-
-
 @lib_app.command("sync")
 def sync(
     library_id: int,
@@ -799,7 +782,10 @@ def sync(
                 update_response = httpx.put(
                     f"{BASE_URL}/entities/{existing_entity['id']}",
                     json=new_entity,
-                    params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
+                    params={
+                        "trigger_webhooks_flag": str(not without_webhooks).lower(),
+                        "update_index": "true",
+                    },
                     timeout=60,
                 )
                 if update_response.status_code == 200:
@@ -829,7 +815,10 @@ def sync(
                 create_response = httpx.post(
                     f"{BASE_URL}/libraries/{library_id}/entities",
                     json=new_entity,
-                    params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
+                    params={
+                        "trigger_webhooks_flag": str(not without_webhooks).lower(),
+                        "update_index": "true",
+                    },
                     timeout=60,
                 )
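The CLI side of the change, in short: instead of touching last_scan_at one entity at a time, reindex now collects unscanned IDs and posts them to the new batch endpoint in fixed-size slices. A minimal standalone sketch of that request loop (BASE_URL and a running memos server are assumptions; the IDs are illustrative):

import httpx

BASE_URL = "http://localhost:8080"  # assumed default; point this at your server

def index_in_batches(entity_ids: list[int], batch_size: int = 32) -> None:
    """Post entity IDs to /entities/batch-index in fixed-size slices."""
    with httpx.Client() as client:
        for i in range(0, len(entity_ids), batch_size):
            batch = entity_ids[i : i + batch_size]
            response = client.post(
                f"{BASE_URL}/entities/batch-index",
                json={"entity_ids": batch},
                timeout=60,
            )
            if response.status_code != 204:
                print(f"Batch failed: {response.status_code} - {response.text}")

index_in_batches([1, 2, 3, 4, 5], batch_size=2)

Larger batches amortize both HTTP round-trips and embedding calls, at the cost of longer individual write transactions.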

View File

@@ -191,7 +191,10 @@ def scan_default_library(
 def reindex_default_library(
     force: bool = typer.Option(
         False, "--force", help="Force recreate FTS and vector tables before reindexing"
-    )
+    ),
+    batch_size: int = typer.Option(
+        1, "--batch-size", "-bs", help="Batch size for processing files"
+    ),
 ):
     """
     Reindex the default library for memos.
@@ -215,7 +218,7 @@ def reindex_default_library(
     # Reindex the library
     print(f"Reindexing library: {default_library['name']}")

-    reindex(default_library["id"], force=force, folders=None)
+    reindex(default_library["id"], force=force, folders=None, batch_size=batch_size)


 @app.command("record")

View File

@@ -25,13 +25,12 @@ from .models import (
     EntityMetadataModel,
     EntityTagModel,
 )
-import numpy as np
 from collections import defaultdict
 from .embedding import get_embeddings
 import logging
 from sqlite_vec import serialize_float32
 import time
-import asyncio
+import json

 logger = logging.getLogger(__name__)
@@ -100,7 +99,11 @@ def add_folders(library_id: int, folders: NewFoldersParam, db: Session) -> Library:
     return Library(**db_library.__dict__)


-def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity:
+def create_entity(
+    library_id: int,
+    entity: NewEntityParam,
+    db: Session,
+) -> Entity:
     tags = entity.tags
     metadata_entries = entity.metadata_entries
@@ -146,6 +149,7 @@ def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity:
         db.add(entity_metadata)
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
@@ -185,13 +189,9 @@ def remove_entity(entity_id: int, db: Session):
     entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
     if entity:
         # Delete the entity from FTS and vec tables first
-        db.execute(
-            text("DELETE FROM entities_fts WHERE id = :id"),
-            {"id": entity_id}
-        )
-        db.execute(
-            text("DELETE FROM entities_vec WHERE rowid = :id"),
-            {"id": entity_id}
-        )
+        db.execute(text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity_id})
+        db.execute(
+            text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity_id}
+        )

         # Then delete the entity itself
@@ -241,7 +241,9 @@ def find_entities_by_ids(entity_ids: List[int], db: Session) -> List[Entity]:
 def update_entity(
-    entity_id: int, updated_entity: UpdateEntityParam, db: Session
+    entity_id: int,
+    updated_entity: UpdateEntityParam,
+    db: Session,
 ) -> Entity:
     db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
@@ -298,6 +300,7 @@ def update_entity(
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
@@ -312,7 +315,11 @@ def touch_entity(entity_id: int, db: Session) -> bool:
     return False


-def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
+def update_entity_tags(
+    entity_id: int,
+    tags: List[str],
+    db: Session,
+) -> Entity:
     db_entity = get_entity_by_id(entity_id, db)
     if not db_entity:
         raise ValueError(f"Entity with id {entity_id} not found")
@@ -339,6 +346,7 @@ def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
@@ -369,11 +377,14 @@ def add_new_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)


 def update_entity_metadata_entries(
-    entity_id: int, updated_metadata: List[EntityMetadataParam], db: Session
+    entity_id: int,
+    updated_metadata: List[EntityMetadataParam],
+    db: Session,
 ) -> Entity:
     db_entity = get_entity_by_id(entity_id, db)
@@ -421,6 +432,7 @@ def update_entity_metadata_entries(
     db.commit()
     db.refresh(db_entity)
+
     return Entity(**db_entity.__dict__)
@@ -478,7 +490,9 @@ def full_text_search(
         params["start"] = str(start)
         params["end"] = str(end)

-    sql_query += " ORDER BY bm25(entities_fts), entities.file_created_at DESC LIMIT :limit"
+    sql_query += (
+        " ORDER BY bm25(entities_fts), entities.file_created_at DESC LIMIT :limit"
+    )

     result = db.execute(text(sql_query), params).fetchall()
@@ -489,7 +503,7 @@ def full_text_search(
     return ids


-async def vec_search(
+def vec_search(
     query: str,
     db: Session,
     limit: int = 200,
@@ -497,7 +511,7 @@ async def vec_search(
     start: Optional[int] = None,
     end: Optional[int] = None,
 ) -> List[int]:
-    query_embedding = await get_embeddings([query])
+    query_embedding = get_embeddings([query])
     if not query_embedding:
         return []
@@ -517,9 +531,7 @@ async def vec_search(
         sql_query += f" AND entities.library_id IN ({library_ids_str})"

     if start is not None and end is not None:
-        sql_query += (
-            " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
-        )
+        sql_query += " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
         params["start"] = str(start)
         params["end"] = str(end)
@@ -547,7 +559,7 @@ def reciprocal_rank_fusion(
     return sorted_results


-async def hybrid_search(
+def hybrid_search(
     query: str,
     db: Session,
     limit: int = 200,
@@ -563,7 +575,7 @@ async def hybrid_search(
     logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")

     vec_start = time.time()
-    vec_results = await vec_search(query, db, limit, library_ids, start, end)
+    vec_results = vec_search(query, db, limit, library_ids, start, end)
     vec_end = time.time()
     logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")
@@ -595,7 +607,7 @@ async def hybrid_search(
     return result


-async def list_entities(
+def list_entities(
     db: Session,
     limit: int = 200,
     library_ids: Optional[List[int]] = None,
@@ -609,7 +621,7 @@ async def list_entities(
     if start is not None and end is not None:
         query = query.filter(
-            func.strftime("%s", EntityModel.file_created_at, 'utc').between(
+            func.strftime("%s", EntityModel.file_created_at, "utc").between(
                 str(start), str(end)
             )
         )
@@ -646,7 +658,7 @@ def get_entity_context(
         db.query(EntityModel)
         .filter(
             EntityModel.library_id == library_id,
-            EntityModel.file_created_at < target_entity.file_created_at
+            EntityModel.file_created_at < target_entity.file_created_at,
         )
         .order_by(EntityModel.file_created_at.desc())
         .limit(prev)
@@ -662,7 +674,7 @@ def get_entity_context(
         db.query(EntityModel)
         .filter(
             EntityModel.library_id == library_id,
-            EntityModel.file_created_at > target_entity.file_created_at
+            EntityModel.file_created_at > target_entity.file_created_at,
         )
         .order_by(EntityModel.file_created_at.asc())
         .limit(next)
@@ -672,3 +684,169 @@ def get_entity_context(
     next_entities = [Entity(**entity.__dict__) for entity in next_entities]

     return prev_entities, next_entities
+
+
+def process_ocr_result(value, max_length=4096):
+    try:
+        ocr_data = json.loads(value)
+        if isinstance(ocr_data, list) and all(
+            isinstance(item, dict)
+            and "dt_boxes" in item
+            and "rec_txt" in item
+            and "score" in item
+            for item in ocr_data
+        ):
+            return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
+        else:
+            return json.dumps(ocr_data, indent=2)
+    except json.JSONDecodeError:
+        return value
+
+
+def prepare_fts_data(entity: EntityModel) -> tuple[str, str]:
+    tags = ", ".join([tag.name for tag in entity.tags])
+    fts_metadata = "\n".join(
+        [
+            f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
+            for entry in entity.metadata_entries
+        ]
+    )
+    return tags, fts_metadata
+
+
+def prepare_vec_data(entity: EntityModel) -> str:
+    vec_metadata = "\n".join(
+        [
+            f"{entry.key}: {entry.value}"
+            for entry in entity.metadata_entries
+            if entry.key != "ocr_result"
+        ]
+    )
+    ocr_result = next(
+        (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
+        "",
+    )
+    vec_metadata += f"\nocr_result: {process_ocr_result(ocr_result, max_length=128)}"
+    return vec_metadata
+
+
+def update_entity_index(entity: EntityModel, db: Session):
+    """Update both FTS and vector indexes for an entity"""
+    try:
+        # Update FTS index
+        tags, fts_metadata = prepare_fts_data(entity)
+        db.execute(
+            text(
+                """
+                INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
+                VALUES(:id, :filepath, :tags, :metadata)
+                """
+            ),
+            {
+                "id": entity.id,
+                "filepath": entity.filepath,
+                "tags": tags,
+                "metadata": fts_metadata,
+            },
+        )
+
+        # Update vector index
+        vec_metadata = prepare_vec_data(entity)
+        embeddings = get_embeddings([vec_metadata])
+        if embeddings and embeddings[0]:
+            db.execute(
+                text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
+            )
+            db.execute(
+                text(
+                    """
+                    INSERT INTO entities_vec (rowid, embedding)
+                    VALUES (:id, :embedding)
+                    """
+                ),
+                {
+                    "id": entity.id,
+                    "embedding": serialize_float32(embeddings[0]),
+                },
+            )
+
+        db.commit()
+    except Exception as e:
+        logger.error(f"Error updating indexes for entity {entity.id}: {e}")
+        db.rollback()
+        raise
+
+
+def batch_update_entity_indices(entity_ids: List[int], db: Session):
+    """Batch update both FTS and vector indexes for multiple entities"""
+    try:
+        # Fetch the entities
+        entities = db.query(EntityModel).filter(EntityModel.id.in_(entity_ids)).all()
+        found_ids = {entity.id for entity in entities}
+
+        # Verify that all requested entities were found
+        missing_ids = set(entity_ids) - found_ids
+        if missing_ids:
+            raise ValueError(f"Entities not found: {missing_ids}")
+
+        # Prepare FTS and vector data for all entities
+        fts_data = []
+        vec_metadata_list = []
+        for entity in entities:
+            # Prepare FTS data
+            tags, fts_metadata = prepare_fts_data(entity)
+            fts_data.append((entity.id, entity.filepath, tags, fts_metadata))
+
+            # Prepare vector data
+            vec_metadata = prepare_vec_data(entity)
+            vec_metadata_list.append(vec_metadata)
+
+        # Batch update FTS table
+        for entity_id, filepath, tags, metadata in fts_data:
+            db.execute(
+                text(
+                    """
+                    INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
+                    VALUES(:id, :filepath, :tags, :metadata)
+                    """
+                ),
+                {
+                    "id": entity_id,
+                    "filepath": filepath,
+                    "tags": tags,
+                    "metadata": metadata,
+                },
+            )
+
+        # Batch get embeddings
+        embeddings = get_embeddings(vec_metadata_list)
+
+        # Batch update vector table
+        if embeddings:
+            for entity, embedding in zip(entities, embeddings):
+                if embedding:  # Check if embedding is not empty
+                    db.execute(
+                        text("DELETE FROM entities_vec WHERE rowid = :id"),
+                        {"id": entity.id},
+                    )
+                    db.execute(
+                        text(
+                            """
+                            INSERT INTO entities_vec (rowid, embedding)
+                            VALUES (:id, :embedding)
+                            """
+                        ),
+                        {
+                            "id": entity.id,
+                            "embedding": serialize_float32(embedding),
+                        },
+                    )
+
+        db.commit()
+    except Exception as e:
+        logger.error(f"Error batch updating indexes: {e}")
+        db.rollback()
+        raise
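A hedged usage sketch for the new helper: one call per batch means one SQLite write transaction per batch, which is the core of the concurrency fix. The module path and the session factory name are assumptions:

from sqlalchemy.orm import sessionmaker
from memos.crud import batch_update_entity_indices  # module path assumed

def reindex_ids(ids: list[int], session_factory: sessionmaker) -> None:
    """Index one batch of entity IDs inside a single session/transaction."""
    with session_factory() as db:
        try:
            batch_update_entity_indices(ids, db)
        except ValueError as exc:
            # Raised when some of the requested entity IDs do not exist.
            print(f"Skipping batch with missing entities: {exc}")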

View File

@@ -58,11 +58,11 @@ def generate_embeddings(texts: List[str]) -> List[List[float]]:
     return embeddings.tolist()


-async def get_embeddings(texts: List[str]) -> List[List[float]]:
+def get_embeddings(texts: List[str]) -> List[List[float]]:
     if settings.embedding.use_local:
         embeddings = generate_embeddings(texts)
     else:
-        embeddings = await get_remote_embeddings(texts)
+        embeddings = get_remote_embeddings(texts)

     # Round the embedding values to 5 decimal places
     return [
@@ -71,12 +71,12 @@ async def get_embeddings(texts: List[str]) -> List[List[float]]:
     ]


-async def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
+def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
     payload = {"model": settings.embedding.model, "input": texts}

-    async with httpx.AsyncClient() as client:
+    with httpx.Client() as client:
         try:
-            response = await client.post(settings.embedding.endpoint, json=payload)
+            response = client.post(settings.embedding.endpoint, json=payload)
             response.raise_for_status()
             result = response.json()
             return result["embeddings"]

View File

@@ -21,11 +21,6 @@ from sqlalchemy import text
 import sqlite_vec
 import sys
 from pathlib import Path
-import json
-from .embedding import get_embeddings
-from sqlite_vec import serialize_float32
-import asyncio
-import threading


 class Base(DeclarativeBase):
@@ -246,7 +241,13 @@ def recreate_fts_and_vec_tables():
 def init_database():
     """Initialize the database."""
     db_path = get_database_path()
-    engine = create_engine(f"sqlite:///{db_path}")
+    engine = create_engine(
+        f"sqlite:///{db_path}",
+        pool_size=3,
+        max_overflow=0,
+        pool_timeout=30,
+        connect_args={"timeout": 30}
+    )

     # Use a single event listener for both extension loading and WAL mode setting
     event.listen(engine, "connect", load_extension)
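These pool settings are doing real work for SQLite: a small capped pool plus a 30-second busy timeout makes concurrent writers queue instead of failing with "database is locked". A self-contained sketch with the same values (the database path is illustrative):

from sqlalchemy import create_engine

engine = create_engine(
    "sqlite:///memos.db",          # illustrative path
    pool_size=3,                   # keep at most 3 pooled connections
    max_overflow=0,                # never open connections beyond the pool
    pool_timeout=30,               # wait up to 30s for a free connection
    connect_args={"timeout": 30},  # SQLite busy timeout, in seconds
)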
@@ -339,142 +340,3 @@ def init_default_libraries(session, default_plugins):
             session.add(library_plugin)

     session.commit()
-
-
-async def update_or_insert_entities_vec(session, target_id, embedding):
-    try:
-        session.execute(
-            text("DELETE FROM entities_vec WHERE rowid = :id"),
-            {"id": target_id}
-        )
-        session.execute(
-            text(
-                """
-                INSERT INTO entities_vec (rowid, embedding)
-                VALUES (:id, :embedding)
-                """
-            ),
-            {
-                "id": target_id,
-                "embedding": serialize_float32(embedding),
-            },
-        )
-        session.commit()
-    except Exception as e:
-        print(f"Error updating entities_vec: {e}")
-        session.rollback()
-
-
-def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
-    try:
-        session.execute(
-            text(
-                """
-                INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
-                VALUES(:id, :filepath, :tags, :metadata)
-                """
-            ),
-            {
-                "id": target_id,
-                "filepath": filepath,
-                "tags": tags,
-                "metadata": metadata,
-            },
-        )
-        session.commit()
-    except Exception as e:
-        print(f"Error updating entities_fts: {e}")
-        session.rollback()
-
-
-async def update_fts_and_vec(mapper, connection, entity: EntityModel):
-    session = Session(bind=connection)
-
-    # Prepare FTS data
-    tags = ", ".join([tag.name for tag in entity.tags])
-
-    # Process metadata entries
-    def process_ocr_result(value, max_length=4096):
-        try:
-            ocr_data = json.loads(value)
-            if isinstance(ocr_data, list) and all(
-                isinstance(item, dict)
-                and "dt_boxes" in item
-                and "rec_txt" in item
-                and "score" in item
-                for item in ocr_data
-            ):
-                return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
-            else:
-                return json.dumps(ocr_data, indent=2)
-        except json.JSONDecodeError:
-            return value
-
-    fts_metadata = "\n".join(
-        [
-            f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
-            for entry in entity.metadata_entries
-        ]
-    )
-
-    # Update FTS table
-    update_or_insert_entities_fts(
-        session, entity.id, entity.filepath, tags, fts_metadata
-    )
-
-    # Prepare vector data
-    metadata_text = "\n".join(
-        [
-            f"{entry.key}: {entry.value}"
-            for entry in entity.metadata_entries
-            if entry.key != "ocr_result"
-        ]
-    )
-
-    # Add ocr_result at the end of metadata_text using process_ocr_result
-    ocr_result = next(
-        (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
-        "",
-    )
-    processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
-    metadata_text += f"\nocr_result: {processed_ocr_result}"
-
-    # Use the new get_embeddings function
-    embeddings = await get_embeddings([metadata_text])
-    if not embeddings:
-        embedding = []
-    else:
-        embedding = embeddings[0]
-
-    # Update vector table
-    if embedding:
-        await update_or_insert_entities_vec(session, entity.id, embedding)
-
-
-def delete_fts_and_vec(mapper, connection, entity: EntityModel):
-    connection.execute(
-        text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity.id}
-    )
-    connection.execute(
-        text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
-    )
-
-
-def run_async(coro):
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    return loop.run_until_complete(coro)
-
-
-def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
-    def run_in_thread():
-        run_async(update_fts_and_vec(mapper, connection, entity))
-
-    thread = threading.Thread(target=run_in_thread)
-    thread.start()
-    thread.join()
-
-
-# Add event listeners for EntityModel
-event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
-event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
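For contrast, the deleted listeners spawned a thread and a fresh event loop for every ORM write, so index updates raced the writing transaction. The replacement pattern is plain synchronous code driven by the caller; a toy illustration with stubs (the real functions live in the crud module):

def update_entity_index(entity_id: int) -> None:
    # Stub for crud.update_entity_index: runs on the caller's thread and session.
    print(f"indexing entity {entity_id} synchronously")

def create_entity(entity_id: int, update_index: bool = False) -> int:
    # Old design: an after_insert listener fired here on a new thread.
    # New design: the caller opts in explicitly, in the same transaction flow.
    if update_index:
        update_entity_index(entity_id)
    return entity_id

create_entity(42, update_index=True)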

View File

@@ -276,3 +276,7 @@ class SearchResult(BaseModel):
 class EntityContext(BaseModel):
     prev: List[Entity]
     next: List[Entity]
+
+
+class BatchIndexRequest(BaseModel):
+    entity_ids: List[int]
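The wire format for the new endpoint follows directly from this model; for example (IDs illustrative; the pydantic v2 API is assumed, use .json() on v1):

from typing import List
from pydantic import BaseModel

class BatchIndexRequest(BaseModel):
    entity_ids: List[int]

request_body = BatchIndexRequest(entity_ids=[1, 2, 3])
print(request_body.model_dump_json())  # {"entity_ids":[1,2,3]}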

View File

@@ -43,6 +43,7 @@ from .schemas import (
     SearchHit,
     RequestParams,
     EntityContext,
+    BatchIndexRequest,
 )
 from .read_metadata import read_metadata
 from .logging_config import LOGGING_CONFIG
@@ -239,6 +240,7 @@ async def new_entity(
     db: Session = Depends(get_db),
     plugins: Annotated[List[int] | None, Query()] = None,
     trigger_webhooks_flag: bool = True,
+    update_index: bool = False,
 ):
     library = crud.get_library_by_id(library_id, db)
     if library is None:
@@ -249,6 +251,10 @@ async def new_entity(
     entity = crud.create_entity(library_id, new_entity, db)

     if trigger_webhooks_flag:
         await trigger_webhooks(library, entity, request, plugins)
+
+    if update_index:
+        crud.update_entity_index(entity, db)
+
     return entity
@@ -346,6 +352,7 @@ async def update_entity(
     db: Session = Depends(get_db),
     trigger_webhooks_flag: bool = False,
     plugins: Annotated[List[int] | None, Query()] = None,
+    update_index: bool = False,
 ):
     entity = crud.find_entity_by_id(entity_id, db)
     if entity is None:
@@ -364,6 +371,10 @@ async def update_entity(
                 status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
             )
         await trigger_webhooks(library, entity, request, plugins)
+
+    if update_index:
+        crud.update_entity_index(entity, db)
+
     return entity
@@ -384,6 +395,46 @@ def update_entity_last_scan_at(entity_id: int, db: Session = Depends(get_db)):
     )


+@app.post(
+    "/entities/{entity_id}/index",
+    status_code=status.HTTP_204_NO_CONTENT,
+    tags=["entity"],
+)
+def update_index(entity_id: int, db: Session = Depends(get_db)):
+    """
+    Update the FTS and vector indexes for an entity.
+    """
+    entity = crud.get_entity_by_id(entity_id, db)
+    if entity is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Entity not found",
+        )
+    crud.update_entity_index(entity, db)
+
+
+@app.post(
+    "/entities/batch-index",
+    status_code=status.HTTP_204_NO_CONTENT,
+    tags=["entity"],
+)
+async def batch_update_index(
+    request: BatchIndexRequest,
+    db: Session = Depends(get_db)
+):
+    """
+    Batch update the FTS and vector indexes for multiple entities.
+    """
+    try:
+        crud.batch_update_entity_indices(request.entity_ids, db)
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=str(e)
+        )
+
+
 @app.put("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
 def replace_entity_tags(
     entity_id: int, update_tags: UpdateEntityTagsParam, db: Session = Depends(get_db)
@@ -614,12 +665,12 @@ async def search_entities_v2(
     try:
         if q.strip() == "":
             # Use list_entities when q is empty
-            entities = await crud.list_entities(
+            entities = crud.list_entities(
                 db=db, limit=limit, library_ids=library_ids, start=start, end=end
             )
         else:
             # Use hybrid_search when q is not empty
-            entities = await crud.hybrid_search(
+            entities = crud.hybrid_search(
                 query=q,
                 db=db,
                 limit=limit,
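Putting the new routes together, a client can force index maintenance for one entity or many; a sketch against a running server (BASE_URL is an assumption; a 404 on batch-index reports the missing IDs):

import httpx

BASE_URL = "http://localhost:8080"  # assumed

with httpx.Client(timeout=60) as client:
    single = client.post(f"{BASE_URL}/entities/123/index")
    print(single.status_code)  # 204 on success, 404 if entity 123 is missing

    batch = client.post(
        f"{BASE_URL}/entities/batch-index",
        json={"entity_ids": [123, 124]},
    )
    print(batch.status_code)  # 204 on success, 404 listing missing IDs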

View File

@@ -85,6 +85,10 @@ def setup_library_with_entity(client):
     assert entity_response.status_code == 200
     entity_id = entity_response.json()["id"]

+    # Update the entity's index
+    index_response = client.post(f"/entities/{entity_id}/index")
+    assert index_response.status_code == 204
+
     return library_id, folder_id, entity_id