Mirror of https://github.com/tcsenpai/pensieve.git, synced 2025-06-06 03:05:25 +00:00
feat: concurrency problem solved
This commit is contained in:
parent 22c2158a58
commit 7adb08e6d9
@@ -507,7 +507,10 @@ async def add_entity(
         post_response = await client.post(
             f"{BASE_URL}/libraries/{library_id}/entities",
             json=new_entity,
-            params={"plugins": plugins} if plugins else {},
+            params={
+                "plugins": plugins,
+                "update_index": "true"
+            } if plugins else {"update_index": "true"},
             timeout=60,
         )
         if 200 <= post_response.status_code < 300:
@@ -546,6 +549,7 @@ async def update_entity(
             json=new_entity,
             params={
                 "trigger_webhooks_flag": "true",
+                "update_index": "true",
                 **({"plugins": plugins} if plugins else {}),
             },
             timeout=60,
@@ -581,6 +585,7 @@ def reindex(
     force: bool = typer.Option(
         False, "--force", help="Force recreate FTS and vector tables before reindexing"
     ),
+    batch_size: int = typer.Option(1, "--batch-size", "-bs", help="Batch size for processing entities"),
 ):
     print(f"Reindexing library {library_id}")
 
@@ -608,7 +613,7 @@ def reindex(
         recreate_fts_and_vec_tables()
         print("FTS and vector tables have been recreated.")
 
-    with httpx.Session() as client:
+    with httpx.Client() as client:
         total_entities = 0
 
         # Get total entity count for all folders
@@ -647,56 +652,34 @@ def reindex(
                 if not entities:
                     break
 
-                # Update last_scan_at for each entity
-                for entity in entities:
-                    if entity["id"] in scanned_entities:
-                        continue
-
-                    update_response = client.post(
-                        f"{BASE_URL}/entities/{entity['id']}/last-scan-at"
-                    )
-                    if update_response.status_code != 204:
-                        print(
-                            f"Failed to update last_scan_at for entity {entity['id']}: {update_response.status_code} - {update_response.text}"
-                        )
-                    else:
-                        scanned_entities.add(entity["id"])
-
-                    pbar.update(1)
+                # Collect the entity IDs that still need processing
+                entity_ids = [
+                    entity["id"]
+                    for entity in entities
+                    if entity["id"] not in scanned_entities
+                ]
+
+                # Process in batches of batch_size
+                for i in range(0, len(entity_ids), batch_size):
+                    batch_ids = entity_ids[i:i + batch_size]
+                    if batch_ids:
+                        batch_response = client.post(
+                            f"{BASE_URL}/entities/batch-index",
+                            json={"entity_ids": batch_ids},
+                            timeout=60,
+                        )
+                        if batch_response.status_code != 204:
+                            print(
+                                f"Failed to update batch: {batch_response.status_code} - {batch_response.text}"
+                            )
+                        pbar.update(len(batch_ids))
+                        scanned_entities.update(batch_ids)
 
                 offset += limit
 
             print(f"Reindexing completed for library {library_id}")
-
-
-async def check_and_index_entity(client, entity_id, entity_last_scan_at):
-    try:
-        index_response = client.get(f"{BASE_URL}/entities/{entity_id}/index")
-        if index_response.status_code == 200:
-            index_data = index_response.json()
-            if index_data["last_scan_at"] is None:
-                return entity_last_scan_at is not None
-            index_last_scan_at = datetime.fromtimestamp(index_data["last_scan_at"])
-            entity_last_scan_at = datetime.fromisoformat(entity_last_scan_at)
-
-            if index_last_scan_at >= entity_last_scan_at:
-                return False  # Index is up to date, no need to update
-        return True  # Index doesn't exist or needs update
-    except httpx.HTTPStatusError as e:
-        if e.response.status_code == 404:
-            return True  # Index doesn't exist, need to create
-        raise  # Re-raise other HTTP errors
-
-
-async def index_batch(client, entity_ids):
-    index_response = client.post(
-        f"{BASE_URL}/entities/batch-index",
-        json=entity_ids,
-        timeout=60,
-    )
-    return index_response
-
-
 @lib_app.command("sync")
 def sync(
     library_id: int,
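Taken on its own, the new loop above trades one HTTP round trip per entity for one round trip per batch. A minimal standalone sketch of the same flow, assuming a local server URL (the BASE_URL value here is an assumption, adjust to your deployment) and the /entities/batch-index route introduced later in this commit:

import httpx

BASE_URL = "http://localhost:8080"  # assumed default; adjust to your server

def index_in_batches(entity_ids, batch_size=1):
    # POST IDs to /entities/batch-index in chunks; the endpoint returns 204 on success.
    with httpx.Client() as client:  # synchronous client, as in the commit
        for i in range(0, len(entity_ids), batch_size):
            batch = entity_ids[i : i + batch_size]
            resp = client.post(
                f"{BASE_URL}/entities/batch-index",
                json={"entity_ids": batch},
                timeout=60,
            )
            if resp.status_code != 204:
                print(f"Failed to update batch: {resp.status_code} - {resp.text}")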
@@ -799,7 +782,10 @@ def sync(
             update_response = httpx.put(
                 f"{BASE_URL}/entities/{existing_entity['id']}",
                 json=new_entity,
-                params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
+                params={
+                    "trigger_webhooks_flag": str(not without_webhooks).lower(),
+                    "update_index": "true",
+                },
                 timeout=60,
             )
             if update_response.status_code == 200:
@@ -829,7 +815,10 @@ def sync(
             create_response = httpx.post(
                 f"{BASE_URL}/libraries/{library_id}/entities",
                 json=new_entity,
-                params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
+                params={
+                    "trigger_webhooks_flag": str(not without_webhooks).lower(),
+                    "update_index": "true",
+                },
                 timeout=60,
             )
 
@@ -191,7 +191,10 @@ def scan_default_library(
 def reindex_default_library(
     force: bool = typer.Option(
         False, "--force", help="Force recreate FTS and vector tables before reindexing"
-    )
+    ),
+    batch_size: int = typer.Option(
+        1, "--batch-size", "-bs", help="Batch size for processing files"
+    ),
 ):
     """
     Reindex the default library for memos.
@@ -215,7 +218,7 @@ def reindex_default_library(
 
     # Reindex the library
     print(f"Reindexing library: {default_library['name']}")
-    reindex(default_library["id"], force=force, folders=None)
+    reindex(default_library["id"], force=force, folders=None, batch_size=batch_size)
 
 
 @app.command("record")

246 memos/crud.py

@@ -25,13 +25,12 @@ from .models import (
     EntityMetadataModel,
     EntityTagModel,
 )
-import numpy as np
 from collections import defaultdict
 from .embedding import get_embeddings
 import logging
 from sqlite_vec import serialize_float32
 import time
-import asyncio
+import json
 
 logger = logging.getLogger(__name__)
 
@@ -100,7 +99,11 @@ def add_folders(library_id: int, folders: NewFoldersParam, db: Session) -> Library:
     return Library(**db_library.__dict__)
 
 
-def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity:
+def create_entity(
+    library_id: int,
+    entity: NewEntityParam,
+    db: Session,
+) -> Entity:
     tags = entity.tags
     metadata_entries = entity.metadata_entries
 
@@ -146,6 +149,7 @@ def create_entity(library_id: int, entity: NewEntityParam, db: Session) -> Entity:
         db.add(entity_metadata)
     db.commit()
     db.refresh(db_entity)
 
     return Entity(**db_entity.__dict__)
 
@@ -185,15 +189,11 @@ def remove_entity(entity_id: int, db: Session):
     entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
     if entity:
         # Delete the entity from FTS and vec tables first
+        db.execute(text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity_id})
         db.execute(
-            text("DELETE FROM entities_fts WHERE id = :id"),
-            {"id": entity_id}
-        )
-        db.execute(
-            text("DELETE FROM entities_vec WHERE rowid = :id"),
-            {"id": entity_id}
+            text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity_id}
         )
 
         # Then delete the entity itself
         db.delete(entity)
         db.commit()
@@ -241,7 +241,9 @@ def find_entities_by_ids(entity_ids: List[int], db: Session) -> List[Entity]:
 
 
 def update_entity(
-    entity_id: int, updated_entity: UpdateEntityParam, db: Session
+    entity_id: int,
+    updated_entity: UpdateEntityParam,
+    db: Session,
 ) -> Entity:
     db_entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
 
@@ -298,6 +300,7 @@ def update_entity(
 
     db.commit()
     db.refresh(db_entity)
 
     return Entity(**db_entity.__dict__)
 
@@ -312,14 +315,18 @@ def touch_entity(entity_id: int, db: Session) -> bool:
     return False
 
 
-def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
+def update_entity_tags(
+    entity_id: int,
+    tags: List[str],
+    db: Session,
+) -> Entity:
     db_entity = get_entity_by_id(entity_id, db)
     if not db_entity:
         raise ValueError(f"Entity with id {entity_id} not found")
 
     # Clear existing tags
     db.query(EntityTagModel).filter(EntityTagModel.entity_id == entity_id).delete()
 
     for tag_name in tags:
         tag = db.query(TagModel).filter(TagModel.name == tag_name).first()
         if not tag:
@@ -333,12 +340,13 @@ def update_entity_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
             source=MetadataSource.PLUGIN_GENERATED,
         )
         db.add(entity_tag)
 
     # Update last_scan_at in the same transaction
     db_entity.last_scan_at = func.now()
 
     db.commit()
     db.refresh(db_entity)
 
     return Entity(**db_entity.__dict__)
 
@@ -363,17 +371,20 @@ def add_new_tags(entity_id: int, tags: List[str], db: Session) -> Entity:
             source=MetadataSource.PLUGIN_GENERATED,
         )
         db.add(entity_tag)
 
     # Update last_scan_at in the same transaction
     db_entity.last_scan_at = func.now()
 
     db.commit()
     db.refresh(db_entity)
 
     return Entity(**db_entity.__dict__)
 
 
 def update_entity_metadata_entries(
-    entity_id: int, updated_metadata: List[EntityMetadataParam], db: Session
+    entity_id: int,
+    updated_metadata: List[EntityMetadataParam],
+    db: Session,
 ) -> Entity:
     db_entity = get_entity_by_id(entity_id, db)
 
@@ -421,6 +432,7 @@ def update_entity_metadata_entries(
 
     db.commit()
     db.refresh(db_entity)
 
     return Entity(**db_entity.__dict__)
 
@@ -478,7 +490,9 @@ def full_text_search(
         params["start"] = str(start)
         params["end"] = str(end)
 
-    sql_query += " ORDER BY bm25(entities_fts), entities.file_created_at DESC LIMIT :limit"
+    sql_query += (
+        " ORDER BY bm25(entities_fts), entities.file_created_at DESC LIMIT :limit"
+    )
 
     result = db.execute(text(sql_query), params).fetchall()
 
@@ -489,7 +503,7 @@ def full_text_search(
     return ids
 
 
-async def vec_search(
+def vec_search(
     query: str,
     db: Session,
     limit: int = 200,
@@ -497,7 +511,7 @@ async def vec_search(
     start: Optional[int] = None,
     end: Optional[int] = None,
 ) -> List[int]:
-    query_embedding = await get_embeddings([query])
+    query_embedding = get_embeddings([query])
     if not query_embedding:
         return []
 
@@ -517,9 +531,7 @@ async def vec_search(
         sql_query += f" AND entities.library_id IN ({library_ids_str})"
 
     if start is not None and end is not None:
-        sql_query += (
-            " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
-        )
+        sql_query += " AND strftime('%s', entities.file_created_at, 'utc') BETWEEN :start AND :end"
         params["start"] = str(start)
         params["end"] = str(end)
 
@@ -547,7 +559,7 @@ def reciprocal_rank_fusion(
     return sorted_results
 
 
-async def hybrid_search(
+def hybrid_search(
     query: str,
     db: Session,
     limit: int = 200,
@@ -563,7 +575,7 @@ async def hybrid_search(
     logger.info(f"Full-text search took {fts_end - fts_start:.4f} seconds")
 
     vec_start = time.time()
-    vec_results = await vec_search(query, db, limit, library_ids, start, end)
+    vec_results = vec_search(query, db, limit, library_ids, start, end)
     vec_end = time.time()
     logger.info(f"Vector search took {vec_end - vec_start:.4f} seconds")
 
@@ -595,7 +607,7 @@ async def hybrid_search(
     return result
 
 
-async def list_entities(
+def list_entities(
     db: Session,
     limit: int = 200,
     library_ids: Optional[List[int]] = None,
@@ -609,7 +621,7 @@ async def list_entities(
 
     if start is not None and end is not None:
         query = query.filter(
-            func.strftime("%s", EntityModel.file_created_at, 'utc').between(
+            func.strftime("%s", EntityModel.file_created_at, "utc").between(
                 str(start), str(end)
             )
         )
@@ -635,10 +647,10 @@ def get_entity_context(
         )
         .first()
     )
 
     if not target_entity:
         return [], []
 
     # Get previous entities
     prev_entities = []
     if prev > 0:
@@ -646,7 +658,7 @@ def get_entity_context(
             db.query(EntityModel)
             .filter(
                 EntityModel.library_id == library_id,
-                EntityModel.file_created_at < target_entity.file_created_at
+                EntityModel.file_created_at < target_entity.file_created_at,
             )
             .order_by(EntityModel.file_created_at.desc())
             .limit(prev)
@@ -654,7 +666,7 @@ def get_entity_context(
         )
         # Reverse the list to get chronological order and convert to Entity models
        prev_entities = [Entity(**entity.__dict__) for entity in prev_entities][::-1]
 
     # Get next entities
     next_entities = []
     if next > 0:
@@ -662,7 +674,7 @@ def get_entity_context(
             db.query(EntityModel)
             .filter(
                 EntityModel.library_id == library_id,
-                EntityModel.file_created_at > target_entity.file_created_at
+                EntityModel.file_created_at > target_entity.file_created_at,
             )
             .order_by(EntityModel.file_created_at.asc())
             .limit(next)
@@ -670,5 +682,171 @@ def get_entity_context(
         )
         # Convert to Entity models
         next_entities = [Entity(**entity.__dict__) for entity in next_entities]
 
     return prev_entities, next_entities
+
+
+def process_ocr_result(value, max_length=4096):
+    try:
+        ocr_data = json.loads(value)
+        if isinstance(ocr_data, list) and all(
+            isinstance(item, dict)
+            and "dt_boxes" in item
+            and "rec_txt" in item
+            and "score" in item
+            for item in ocr_data
+        ):
+            return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
+        else:
+            return json.dumps(ocr_data, indent=2)
+    except json.JSONDecodeError:
+        return value
+
+
+def prepare_fts_data(entity: EntityModel) -> tuple[str, str]:
+    tags = ", ".join([tag.name for tag in entity.tags])
+    fts_metadata = "\n".join(
+        [
+            f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
+            for entry in entity.metadata_entries
+        ]
+    )
+    return tags, fts_metadata
+
+
+def prepare_vec_data(entity: EntityModel) -> str:
+    vec_metadata = "\n".join(
+        [
+            f"{entry.key}: {entry.value}"
+            for entry in entity.metadata_entries
+            if entry.key != "ocr_result"
+        ]
+    )
+    ocr_result = next(
+        (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
+        "",
+    )
+    vec_metadata += f"\nocr_result: {process_ocr_result(ocr_result, max_length=128)}"
+    return vec_metadata
+
+
+def update_entity_index(entity: EntityModel, db: Session):
+    """Update both FTS and vector indexes for an entity"""
+    try:
+        # Update FTS index
+        tags, fts_metadata = prepare_fts_data(entity)
+        db.execute(
+            text(
+                """
+                INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
+                VALUES(:id, :filepath, :tags, :metadata)
+                """
+            ),
+            {
+                "id": entity.id,
+                "filepath": entity.filepath,
+                "tags": tags,
+                "metadata": fts_metadata,
+            },
+        )
+
+        # Update vector index
+        vec_metadata = prepare_vec_data(entity)
+        embeddings = get_embeddings([vec_metadata])
+
+        if embeddings and embeddings[0]:
+            db.execute(
+                text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
+            )
+            db.execute(
+                text(
+                    """
+                    INSERT INTO entities_vec (rowid, embedding)
+                    VALUES (:id, :embedding)
+                    """
+                ),
+                {
+                    "id": entity.id,
+                    "embedding": serialize_float32(embeddings[0]),
+                },
+            )
+
+        db.commit()
+    except Exception as e:
+        logger.error(f"Error updating indexes for entity {entity.id}: {e}")
+        db.rollback()
+        raise
+
+
+def batch_update_entity_indices(entity_ids: List[int], db: Session):
+    """Batch update both FTS and vector indexes for multiple entities"""
+    try:
+        # Fetch the entities
+        entities = db.query(EntityModel).filter(EntityModel.id.in_(entity_ids)).all()
+        found_ids = {entity.id for entity in entities}
+
+        # Check that every requested entity was found
+        missing_ids = set(entity_ids) - found_ids
+        if missing_ids:
+            raise ValueError(f"Entities not found: {missing_ids}")
+
+        # Prepare FTS data for all entities
+        fts_data = []
+        vec_metadata_list = []
+
+        for entity in entities:
+            # Prepare FTS data
+            tags, fts_metadata = prepare_fts_data(entity)
+            fts_data.append((entity.id, entity.filepath, tags, fts_metadata))
+
+            # Prepare vector data
+            vec_metadata = prepare_vec_data(entity)
+            vec_metadata_list.append(vec_metadata)
+
+        # Batch update FTS table
+        for entity_id, filepath, tags, metadata in fts_data:
+            db.execute(
+                text(
+                    """
+                    INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
+                    VALUES(:id, :filepath, :tags, :metadata)
+                    """
+                ),
+                {
+                    "id": entity_id,
+                    "filepath": filepath,
+                    "tags": tags,
+                    "metadata": metadata,
+                },
+            )
+
+        # Batch get embeddings
+        embeddings = get_embeddings(vec_metadata_list)
+
+        # Batch update vector table
+        if embeddings:
+            for entity, embedding in zip(entities, embeddings):
+                if embedding:  # Check if embedding is not empty
+                    db.execute(
+                        text("DELETE FROM entities_vec WHERE rowid = :id"),
+                        {"id": entity.id},
+                    )
+                    db.execute(
+                        text(
+                            """
+                            INSERT INTO entities_vec (rowid, embedding)
+                            VALUES (:id, :embedding)
+                            """
+                        ),
+                        {
+                            "id": entity.id,
+                            "embedding": serialize_float32(embedding),
+                        },
+                    )
+
+        db.commit()
+    except Exception as e:
+        logger.error(f"Error batch updating indexes: {e}")
+        db.rollback()
+        raise
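The payoff of batch_update_entity_indices is that all vector metadata goes through a single get_embeddings call instead of one call per entity. A rough sketch of the difference, with a hypothetical embed() standing in for the real embedding backend:

from typing import List

def embed(texts: List[str]) -> List[List[float]]:
    # Hypothetical stand-in for get_embeddings; a real backend would call a model.
    return [[float(len(t))] for t in texts]

# Per-entity indexing issues one backend call per entity (N calls for N texts).
def index_one_by_one(texts: List[str]) -> List[List[float]]:
    return [embed([t])[0] for t in texts]

# Batched indexing, as in batch_update_entity_indices: one call covers the batch.
def index_batched(texts: List[str]) -> List[List[float]]:
    return embed(texts)

assert index_one_by_one(["a", "bb"]) == index_batched(["a", "bb"])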
@@ -58,11 +58,11 @@ def generate_embeddings(texts: List[str]) -> List[List[float]]:
     return embeddings.tolist()
 
 
-async def get_embeddings(texts: List[str]) -> List[List[float]]:
+def get_embeddings(texts: List[str]) -> List[List[float]]:
     if settings.embedding.use_local:
         embeddings = generate_embeddings(texts)
     else:
-        embeddings = await get_remote_embeddings(texts)
+        embeddings = get_remote_embeddings(texts)
 
     # Round the embedding values to 5 decimal places
     return [
@@ -71,12 +71,12 @@ async def get_embeddings(texts: List[str]) -> List[List[float]]:
     ]
 
 
-async def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
+def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
     payload = {"model": settings.embedding.model, "input": texts}
 
-    async with httpx.AsyncClient() as client:
+    with httpx.Client() as client:
         try:
-            response = await client.post(settings.embedding.endpoint, json=payload)
+            response = client.post(settings.embedding.endpoint, json=payload)
             response.raise_for_status()
             result = response.json()
             return result["embeddings"]
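Since get_embeddings is now synchronous, callers such as update_entity_index can invoke it directly, with no event loop, thread, or await involved. A small usage sketch, assuming local embeddings are enabled in settings:

from memos.embedding import get_embeddings

# Direct, blocking call; no asyncio.run(), no AsyncClient, no per-call thread.
vectors = get_embeddings(["screenshot of a terminal window", "a cat photo"])
print(len(vectors), len(vectors[0]))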
152 memos/models.py

@@ -21,11 +21,6 @@ from sqlalchemy import text
 import sqlite_vec
 import sys
 from pathlib import Path
-import json
-from .embedding import get_embeddings
-from sqlite_vec import serialize_float32
-import asyncio
-import threading
 
 
 class Base(DeclarativeBase):
@@ -246,7 +241,13 @@ def recreate_fts_and_vec_tables():
 def init_database():
     """Initialize the database."""
     db_path = get_database_path()
-    engine = create_engine(f"sqlite:///{db_path}")
+    engine = create_engine(
+        f"sqlite:///{db_path}",
+        pool_size=3,
+        max_overflow=0,
+        pool_timeout=30,
+        connect_args={"timeout": 30}
+    )
 
     # Use a single event listener for both extension loading and WAL mode setting
     event.listen(engine, "connect", load_extension)
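The new engine arguments bound the connection pool and, via connect_args, pass a 30-second lock timeout down to sqlite3.connect, so concurrent writers wait for the lock instead of failing immediately with "database is locked". The sqlite3-level equivalent, for illustration:

import sqlite3

# connect_args={"timeout": 30} above maps to this parameter: wait up to 30s
# for a write lock before raising sqlite3.OperationalError.
conn = sqlite3.connect("memos.db", timeout=30)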
@@ -339,142 +340,3 @@ def init_default_libraries(session, default_plugins):
         session.add(library_plugin)
 
     session.commit()
-
-
-async def update_or_insert_entities_vec(session, target_id, embedding):
-    try:
-        session.execute(
-            text("DELETE FROM entities_vec WHERE rowid = :id"),
-            {"id": target_id}
-        )
-
-        session.execute(
-            text(
-                """
-                INSERT INTO entities_vec (rowid, embedding)
-                VALUES (:id, :embedding)
-                """
-            ),
-            {
-                "id": target_id,
-                "embedding": serialize_float32(embedding),
-            },
-        )
-        session.commit()
-    except Exception as e:
-        print(f"Error updating entities_vec: {e}")
-        session.rollback()
-
-
-def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
-    try:
-        session.execute(
-            text(
-                """
-                INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
-                VALUES(:id, :filepath, :tags, :metadata)
-                """
-            ),
-            {
-                "id": target_id,
-                "filepath": filepath,
-                "tags": tags,
-                "metadata": metadata,
-            },
-        )
-        session.commit()
-    except Exception as e:
-        print(f"Error updating entities_fts: {e}")
-        session.rollback()
-
-
-async def update_fts_and_vec(mapper, connection, entity: EntityModel):
-    session = Session(bind=connection)
-
-    # Prepare FTS data
-    tags = ", ".join([tag.name for tag in entity.tags])
-
-    # Process metadata entries
-    def process_ocr_result(value, max_length=4096):
-        try:
-            ocr_data = json.loads(value)
-            if isinstance(ocr_data, list) and all(
-                isinstance(item, dict)
-                and "dt_boxes" in item
-                and "rec_txt" in item
-                and "score" in item
-                for item in ocr_data
-            ):
-                return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
-            else:
-                return json.dumps(ocr_data, indent=2)
-        except json.JSONDecodeError:
-            return value
-
-    fts_metadata = "\n".join(
-        [
-            f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
-            for entry in entity.metadata_entries
-        ]
-    )
-
-    # Update FTS table
-    update_or_insert_entities_fts(
-        session, entity.id, entity.filepath, tags, fts_metadata
-    )
-
-    # Prepare vector data
-    metadata_text = "\n".join(
-        [
-            f"{entry.key}: {entry.value}"
-            for entry in entity.metadata_entries
-            if entry.key != "ocr_result"
-        ]
-    )
-
-    # Add ocr_result at the end of metadata_text using process_ocr_result
-    ocr_result = next(
-        (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
-        "",
-    )
-    processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
-    metadata_text += f"\nocr_result: {processed_ocr_result}"
-
-    # Use the new get_embeddings function
-    embeddings = await get_embeddings([metadata_text])
-    if not embeddings:
-        embedding = []
-    else:
-        embedding = embeddings[0]
-
-    # Update vector table
-    if embedding:
-        await update_or_insert_entities_vec(session, entity.id, embedding)
-
-
-def delete_fts_and_vec(mapper, connection, entity: EntityModel):
-    connection.execute(
-        text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity.id}
-    )
-    connection.execute(
-        text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
-    )
-
-
-def run_async(coro):
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    return loop.run_until_complete(coro)
-
-
-def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
-    def run_in_thread():
-        run_async(update_fts_and_vec(mapper, connection, entity))
-
-    thread = threading.Thread(target=run_in_thread)
-    thread.start()
-    thread.join()
-
-# Add event listeners for EntityModel
-event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
-event.listen(EntityModel, "after_update", update_fts_and_vec_sync)

@@ -276,3 +276,7 @@ class SearchResult(BaseModel):
 class EntityContext(BaseModel):
     prev: List[Entity]
     next: List[Entity]
+
+
+class BatchIndexRequest(BaseModel):
+    entity_ids: List[int]
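BatchIndexRequest is the body model for the new batch endpoint; the wire format is just a list of entity IDs. For example:

# JSON body accepted by POST /entities/batch-index:
payload = {"entity_ids": [101, 102, 103]}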
@@ -43,6 +43,7 @@ from .schemas import (
     SearchHit,
     RequestParams,
     EntityContext,
+    BatchIndexRequest,
 )
 from .read_metadata import read_metadata
 from .logging_config import LOGGING_CONFIG
@@ -239,6 +240,7 @@ async def new_entity(
     db: Session = Depends(get_db),
     plugins: Annotated[List[int] | None, Query()] = None,
     trigger_webhooks_flag: bool = True,
+    update_index: bool = False,
 ):
     library = crud.get_library_by_id(library_id, db)
     if library is None:
@@ -249,6 +251,10 @@ async def new_entity(
     entity = crud.create_entity(library_id, new_entity, db)
     if trigger_webhooks_flag:
         await trigger_webhooks(library, entity, request, plugins)
+
+    if update_index:
+        crud.update_entity_index(entity, db)
+
     return entity
 
 
@@ -346,6 +352,7 @@ async def update_entity(
     db: Session = Depends(get_db),
     trigger_webhooks_flag: bool = False,
     plugins: Annotated[List[int] | None, Query()] = None,
+    update_index: bool = False,
 ):
     entity = crud.find_entity_by_id(entity_id, db)
     if entity is None:
@@ -364,6 +371,10 @@ async def update_entity(
                 status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
             )
         await trigger_webhooks(library, entity, request, plugins)
+
+    if update_index:
+        crud.update_entity_index(entity, db)
+
     return entity
 
 
@@ -382,6 +393,46 @@ def update_entity_last_scan_at(entity_id: int, db: Session = Depends(get_db)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Entity not found",
         )
+
+
+@app.post(
+    "/entities/{entity_id}/index",
+    status_code=status.HTTP_204_NO_CONTENT,
+    tags=["entity"],
+)
+def update_index(entity_id: int, db: Session = Depends(get_db)):
+    """
+    Update the FTS and vector indexes for an entity.
+    """
+    entity = crud.get_entity_by_id(entity_id, db)
+    if entity is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Entity not found",
+        )
+
+    crud.update_entity_index(entity, db)
+
+
+@app.post(
+    "/entities/batch-index",
+    status_code=status.HTTP_204_NO_CONTENT,
+    tags=["entity"],
+)
+async def batch_update_index(
+    request: BatchIndexRequest,
+    db: Session = Depends(get_db)
+):
+    """
+    Batch update the FTS and vector indexes for multiple entities.
+    """
+    try:
+        crud.batch_update_entity_indices(request.entity_ids, db)
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=str(e)
+        )
+
+
 @app.put("/entities/{entity_id}/tags", response_model=Entity, tags=["entity"])
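A minimal sketch of exercising the two new routes with FastAPI's TestClient, in the same style as the test fixture updated below (the app import path here is an assumption):

from fastapi.testclient import TestClient
from memos.server import app  # assumed import path

client = TestClient(app)

# Single-entity reindex: 204 on success, 404 for an unknown id.
assert client.post("/entities/1/index").status_code in (204, 404)

# Batch reindex: every id must exist, otherwise the route returns 404.
resp = client.post("/entities/batch-index", json={"entity_ids": [1, 2]})
assert resp.status_code in (204, 404)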
@@ -614,12 +665,12 @@ async def search_entities_v2(
     try:
         if q.strip() == "":
             # Use list_entities when q is empty
-            entities = await crud.list_entities(
+            entities = crud.list_entities(
                 db=db, limit=limit, library_ids=library_ids, start=start, end=end
             )
         else:
             # Use hybrid_search when q is not empty
-            entities = await crud.hybrid_search(
+            entities = crud.hybrid_search(
                 query=q,
                 db=db,
                 limit=limit,
@@ -85,6 +85,10 @@ def setup_library_with_entity(client):
     assert entity_response.status_code == 200
     entity_id = entity_response.json()["id"]
 
+    # Update the entity's index
+    index_response = client.post(f"/entities/{entity_id}/index")
+    assert index_response.status_code == 204
+
     return library_id, folder_id, entity_id
 