refactor: change the SQL used to update vec and fts

This commit is contained in:
arkohut 2024-11-05 21:49:23 +08:00
parent debb37797b
commit 19f36a6807

View File

@ -326,10 +326,12 @@ def init_default_libraries(session, default_plugins):
bind_response = session.query(PluginModel).filter_by(name=plugin.name).first() bind_response = session.query(PluginModel).filter_by(name=plugin.name).first()
if bind_response: if bind_response:
# Check if the LibraryPluginModel already exists # Check if the LibraryPluginModel already exists
existing_library_plugin = session.query(LibraryPluginModel).filter_by( existing_library_plugin = (
library_id=1, plugin_id=bind_response.id session.query(LibraryPluginModel)
).first() .filter_by(library_id=1, plugin_id=bind_response.id)
.first()
)
if not existing_library_plugin: if not existing_library_plugin:
library_plugin = LibraryPluginModel( library_plugin = LibraryPluginModel(
library_id=1, plugin_id=bind_response.id library_id=1, plugin_id=bind_response.id
@ -341,27 +343,23 @@ def init_default_libraries(session, default_plugins):
async def update_or_insert_entities_vec(session, target_id, embedding): async def update_or_insert_entities_vec(session, target_id, embedding):
try: try:
# First, try to update the existing row session.execute(
result = session.execute( text("DELETE FROM entities_vec WHERE rowid = :id"),
text("UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"), {"id": target_id}
)
session.execute(
text(
"""
INSERT INTO entities_vec (rowid, embedding)
VALUES (:id, :embedding)
"""
),
{ {
"id": target_id, "id": target_id,
"embedding": serialize_float32(embedding), "embedding": serialize_float32(embedding),
}, },
) )
# If no row was updated (i.e., the row doesn't exist), then insert a new row
if result.rowcount == 0:
session.execute(
text(
"INSERT INTO entities_vec (rowid, embedding) VALUES (:id, :embedding)"
),
{
"id": target_id,
"embedding": serialize_float32(embedding),
},
)
session.commit() session.commit()
except Exception as e: except Exception as e:
print(f"Error updating entities_vec: {e}") print(f"Error updating entities_vec: {e}")
@ -370,13 +368,11 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata): def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
try: try:
# First, try to update the existing row session.execute(
result = session.execute(
text( text(
""" """
UPDATE entities_fts INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
SET filepath = :filepath, tags = :tags, metadata = :metadata VALUES(:id, :filepath, :tags, :metadata)
WHERE id = :id
""" """
), ),
{ {
@ -386,35 +382,17 @@ def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
"metadata": metadata, "metadata": metadata,
}, },
) )
# If no row was updated (i.e., the row doesn't exist), then insert a new row
if result.rowcount == 0:
session.execute(
text(
"""
INSERT INTO entities_fts(id, filepath, tags, metadata)
VALUES(:id, :filepath, :tags, :metadata)
"""
),
{
"id": target_id,
"filepath": filepath,
"tags": tags,
"metadata": metadata,
},
)
session.commit() session.commit()
except Exception as e: except Exception as e:
print(f"Error updating entities_fts: {e}") print(f"Error updating entities_fts: {e}")
session.rollback() session.rollback()
async def update_fts_and_vec(mapper, connection, target): async def update_fts_and_vec(mapper, connection, entity: EntityModel):
session = Session(bind=connection) session = Session(bind=connection)
# Prepare FTS data # Prepare FTS data
tags = ", ".join([tag.name for tag in target.tags]) tags = ", ".join([tag.name for tag in entity.tags])
# Process metadata entries # Process metadata entries
def process_ocr_result(value, max_length=4096): def process_ocr_result(value, max_length=4096):
@ -436,26 +414,28 @@ async def update_fts_and_vec(mapper, connection, target):
fts_metadata = "\n".join( fts_metadata = "\n".join(
[ [
f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}" f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
for entry in target.metadata_entries for entry in entity.metadata_entries
] ]
) )
# Update FTS table # Update FTS table
update_or_insert_entities_fts(session, target.id, target.filepath, tags, fts_metadata) update_or_insert_entities_fts(
session, entity.id, entity.filepath, tags, fts_metadata
)
# Prepare vector data # Prepare vector data
metadata_text = "\n".join( metadata_text = "\n".join(
[ [
f"{entry.key}: {entry.value}" f"{entry.key}: {entry.value}"
for entry in target.metadata_entries for entry in entity.metadata_entries
if entry.key != "ocr_result" if entry.key != "ocr_result"
] ]
) )
# Add ocr_result at the end of metadata_text using process_ocr_result # Add ocr_result at the end of metadata_text using process_ocr_result
ocr_result = next( ocr_result = next(
(entry.value for entry in target.metadata_entries if entry.key == "ocr_result"), (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
"" "",
) )
processed_ocr_result = process_ocr_result(ocr_result, max_length=128) processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
metadata_text += f"\nocr_result: {processed_ocr_result}" metadata_text += f"\nocr_result: {processed_ocr_result}"
@ -469,15 +449,15 @@ async def update_fts_and_vec(mapper, connection, target):
# Update vector table # Update vector table
if embedding: if embedding:
await update_or_insert_entities_vec(session, target.id, embedding) await update_or_insert_entities_vec(session, entity.id, embedding)
def delete_fts_and_vec(mapper, connection, target): def delete_fts_and_vec(mapper, connection, entity: EntityModel):
connection.execute( connection.execute(
text("DELETE FROM entities_fts WHERE id = :id"), {"id": target.id} text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity.id}
) )
connection.execute( connection.execute(
text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": target.id} text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
) )
@ -487,9 +467,9 @@ def run_async(coro):
return loop.run_until_complete(coro) return loop.run_until_complete(coro)
def update_fts_and_vec_sync(mapper, connection, target): def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
def run_in_thread(): def run_in_thread():
run_async(update_fts_and_vec(mapper, connection, target)) run_async(update_fts_and_vec(mapper, connection, entity))
thread = threading.Thread(target=run_in_thread) thread = threading.Thread(target=run_in_thread)
thread.start() thread.start()