feat(model): add ocr to embedding table

This commit is contained in:
arkohut 2024-10-21 14:49:30 +08:00
parent 29d1e2f0ce
commit a1a5e576c0

View File

@ -438,7 +438,7 @@ async def update_fts_and_vec(mapper, connection, target):
tags = ", ".join([tag.name for tag in target.tags])
# Process metadata entries
def process_ocr_result(value):
def process_ocr_result(value, max_length=4096):
try:
ocr_data = json.loads(value)
if isinstance(ocr_data, list) and all(
@ -448,13 +448,13 @@ async def update_fts_and_vec(mapper, connection, target):
and "score" in item
for item in ocr_data
):
return " ".join(item["rec_txt"] for item in ocr_data)
return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
else:
return json.dumps(ocr_data, indent=2)
except json.JSONDecodeError:
return value
metadata = "\n".join(
fts_metadata = "\n".join(
[
f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
for entry in target.metadata_entries
@ -462,7 +462,7 @@ async def update_fts_and_vec(mapper, connection, target):
)
# Update FTS table
update_or_insert_entities_fts(session, target.id, target.filepath, tags, metadata)
update_or_insert_entities_fts(session, target.id, target.filepath, tags, fts_metadata)
# Prepare vector data
metadata_text = "\n".join(
@ -473,6 +473,14 @@ async def update_fts_and_vec(mapper, connection, target):
]
)
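# Append ocr_result to the end of metadata_text; with max_length=128, process_ocr_result keeps at most the first 128 OCR items (max_length counts list items, not characters)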
ocr_result = next(
(entry.value for entry in target.metadata_entries if entry.key == "ocr_result"),
""
)
processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
metadata_text += f"\nocr_result: {processed_ocr_result}"
# Use the new get_embeddings function
embeddings = await get_embeddings([metadata_text])
if not embeddings:
@ -512,4 +520,4 @@ def update_fts_and_vec_sync(mapper, connection, target):
# Replace the old event listener with the new sync version
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)