feat(model): add ocr to embedding table

This commit is contained in:
arkohut 2024-10-21 14:49:30 +08:00
parent 29d1e2f0ce
commit a1a5e576c0

View File

@ -438,7 +438,7 @@ async def update_fts_and_vec(mapper, connection, target):
tags = ", ".join([tag.name for tag in target.tags]) tags = ", ".join([tag.name for tag in target.tags])
# Process metadata entries # Process metadata entries
def process_ocr_result(value): def process_ocr_result(value, max_length=4096):
try: try:
ocr_data = json.loads(value) ocr_data = json.loads(value)
if isinstance(ocr_data, list) and all( if isinstance(ocr_data, list) and all(
@ -448,13 +448,13 @@ async def update_fts_and_vec(mapper, connection, target):
and "score" in item and "score" in item
for item in ocr_data for item in ocr_data
): ):
return " ".join(item["rec_txt"] for item in ocr_data) return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
else: else:
return json.dumps(ocr_data, indent=2) return json.dumps(ocr_data, indent=2)
except json.JSONDecodeError: except json.JSONDecodeError:
return value return value
metadata = "\n".join( fts_metadata = "\n".join(
[ [
f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}" f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
for entry in target.metadata_entries for entry in target.metadata_entries
@ -462,7 +462,7 @@ async def update_fts_and_vec(mapper, connection, target):
) )
# Update FTS table # Update FTS table
update_or_insert_entities_fts(session, target.id, target.filepath, tags, metadata) update_or_insert_entities_fts(session, target.id, target.filepath, tags, fts_metadata)
# Prepare vector data # Prepare vector data
metadata_text = "\n".join( metadata_text = "\n".join(
@ -473,6 +473,14 @@ async def update_fts_and_vec(mapper, connection, target):
] ]
) )
# Append the OCR text to metadata_text via process_ocr_result; max_length=128 caps the number of OCR items joined (list entries, not characters)
ocr_result = next(
(entry.value for entry in target.metadata_entries if entry.key == "ocr_result"),
""
)
processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
metadata_text += f"\nocr_result: {processed_ocr_result}"
# Use the new get_embeddings function # Use the new get_embeddings function
embeddings = await get_embeddings([metadata_text]) embeddings = await get_embeddings([metadata_text])
if not embeddings: if not embeddings:
@ -512,4 +520,4 @@ def update_fts_and_vec_sync(mapper, connection, target):
# Replace the old event listener with the new sync version # Replace the old event listener with the new sync version
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync) event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
event.listen(EntityModel, "after_update", update_fts_and_vec_sync) event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
event.listen(EntityModel, "after_delete", delete_fts_and_vec) event.listen(EntityModel, "after_delete", delete_fts_and_vec)