mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
refact: change sql to update vec and fts
This commit is contained in:
parent
debb37797b
commit
19f36a6807
@ -326,10 +326,12 @@ def init_default_libraries(session, default_plugins):
|
||||
bind_response = session.query(PluginModel).filter_by(name=plugin.name).first()
|
||||
if bind_response:
|
||||
# Check if the LibraryPluginModel already exists
|
||||
existing_library_plugin = session.query(LibraryPluginModel).filter_by(
|
||||
library_id=1, plugin_id=bind_response.id
|
||||
).first()
|
||||
|
||||
existing_library_plugin = (
|
||||
session.query(LibraryPluginModel)
|
||||
.filter_by(library_id=1, plugin_id=bind_response.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_library_plugin:
|
||||
library_plugin = LibraryPluginModel(
|
||||
library_id=1, plugin_id=bind_response.id
|
||||
@ -341,27 +343,23 @@ def init_default_libraries(session, default_plugins):
|
||||
|
||||
async def update_or_insert_entities_vec(session, target_id, embedding):
|
||||
try:
|
||||
# First, try to update the existing row
|
||||
result = session.execute(
|
||||
text("UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"),
|
||||
session.execute(
|
||||
text("DELETE FROM entities_vec WHERE rowid = :id"),
|
||||
{"id": target_id}
|
||||
)
|
||||
|
||||
session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO entities_vec (rowid, embedding)
|
||||
VALUES (:id, :embedding)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": target_id,
|
||||
"embedding": serialize_float32(embedding),
|
||||
},
|
||||
)
|
||||
|
||||
# If no row was updated (i.e., the row doesn't exist), then insert a new row
|
||||
if result.rowcount == 0:
|
||||
session.execute(
|
||||
text(
|
||||
"INSERT INTO entities_vec (rowid, embedding) VALUES (:id, :embedding)"
|
||||
),
|
||||
{
|
||||
"id": target_id,
|
||||
"embedding": serialize_float32(embedding),
|
||||
},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
print(f"Error updating entities_vec: {e}")
|
||||
@ -370,13 +368,11 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
|
||||
|
||||
def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
|
||||
try:
|
||||
# First, try to update the existing row
|
||||
result = session.execute(
|
||||
session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE entities_fts
|
||||
SET filepath = :filepath, tags = :tags, metadata = :metadata
|
||||
WHERE id = :id
|
||||
INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
|
||||
VALUES(:id, :filepath, :tags, :metadata)
|
||||
"""
|
||||
),
|
||||
{
|
||||
@ -386,35 +382,17 @@ def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
|
||||
"metadata": metadata,
|
||||
},
|
||||
)
|
||||
|
||||
# If no row was updated (i.e., the row doesn't exist), then insert a new row
|
||||
if result.rowcount == 0:
|
||||
session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO entities_fts(id, filepath, tags, metadata)
|
||||
VALUES(:id, :filepath, :tags, :metadata)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": target_id,
|
||||
"filepath": filepath,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
print(f"Error updating entities_fts: {e}")
|
||||
session.rollback()
|
||||
|
||||
|
||||
async def update_fts_and_vec(mapper, connection, target):
|
||||
async def update_fts_and_vec(mapper, connection, entity: EntityModel):
|
||||
session = Session(bind=connection)
|
||||
|
||||
# Prepare FTS data
|
||||
tags = ", ".join([tag.name for tag in target.tags])
|
||||
tags = ", ".join([tag.name for tag in entity.tags])
|
||||
|
||||
# Process metadata entries
|
||||
def process_ocr_result(value, max_length=4096):
|
||||
@ -436,26 +414,28 @@ async def update_fts_and_vec(mapper, connection, target):
|
||||
fts_metadata = "\n".join(
|
||||
[
|
||||
f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
|
||||
for entry in target.metadata_entries
|
||||
for entry in entity.metadata_entries
|
||||
]
|
||||
)
|
||||
|
||||
# Update FTS table
|
||||
update_or_insert_entities_fts(session, target.id, target.filepath, tags, fts_metadata)
|
||||
update_or_insert_entities_fts(
|
||||
session, entity.id, entity.filepath, tags, fts_metadata
|
||||
)
|
||||
|
||||
# Prepare vector data
|
||||
metadata_text = "\n".join(
|
||||
[
|
||||
f"{entry.key}: {entry.value}"
|
||||
for entry in target.metadata_entries
|
||||
for entry in entity.metadata_entries
|
||||
if entry.key != "ocr_result"
|
||||
]
|
||||
)
|
||||
|
||||
# Add ocr_result at the end of metadata_text using process_ocr_result
|
||||
ocr_result = next(
|
||||
(entry.value for entry in target.metadata_entries if entry.key == "ocr_result"),
|
||||
""
|
||||
(entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"),
|
||||
"",
|
||||
)
|
||||
processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
|
||||
metadata_text += f"\nocr_result: {processed_ocr_result}"
|
||||
@ -469,15 +449,15 @@ async def update_fts_and_vec(mapper, connection, target):
|
||||
|
||||
# Update vector table
|
||||
if embedding:
|
||||
await update_or_insert_entities_vec(session, target.id, embedding)
|
||||
await update_or_insert_entities_vec(session, entity.id, embedding)
|
||||
|
||||
|
||||
def delete_fts_and_vec(mapper, connection, target):
|
||||
def delete_fts_and_vec(mapper, connection, entity: EntityModel):
|
||||
connection.execute(
|
||||
text("DELETE FROM entities_fts WHERE id = :id"), {"id": target.id}
|
||||
text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity.id}
|
||||
)
|
||||
connection.execute(
|
||||
text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": target.id}
|
||||
text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id}
|
||||
)
|
||||
|
||||
|
||||
@ -487,9 +467,9 @@ def run_async(coro):
|
||||
return loop.run_until_complete(coro)
|
||||
|
||||
|
||||
def update_fts_and_vec_sync(mapper, connection, target):
|
||||
def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
|
||||
def run_in_thread():
|
||||
run_async(update_fts_and_vec(mapper, connection, target))
|
||||
run_async(update_fts_and_vec(mapper, connection, entity))
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
|
Loading…
x
Reference in New Issue
Block a user