mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(model): add ocr to embedding table
This commit is contained in:
parent
29d1e2f0ce
commit
a1a5e576c0
@ -438,7 +438,7 @@ async def update_fts_and_vec(mapper, connection, target):
|
||||
tags = ", ".join([tag.name for tag in target.tags])
|
||||
|
||||
# Process metadata entries
|
||||
def process_ocr_result(value):
|
||||
def process_ocr_result(value, max_length=4096):
|
||||
try:
|
||||
ocr_data = json.loads(value)
|
||||
if isinstance(ocr_data, list) and all(
|
||||
@ -448,13 +448,13 @@ async def update_fts_and_vec(mapper, connection, target):
|
||||
and "score" in item
|
||||
for item in ocr_data
|
||||
):
|
||||
return " ".join(item["rec_txt"] for item in ocr_data)
|
||||
return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
|
||||
else:
|
||||
return json.dumps(ocr_data, indent=2)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
metadata = "\n".join(
|
||||
fts_metadata = "\n".join(
|
||||
[
|
||||
f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
|
||||
for entry in target.metadata_entries
|
||||
@ -462,7 +462,7 @@ async def update_fts_and_vec(mapper, connection, target):
|
||||
)
|
||||
|
||||
# Update FTS table
|
||||
update_or_insert_entities_fts(session, target.id, target.filepath, tags, metadata)
|
||||
update_or_insert_entities_fts(session, target.id, target.filepath, tags, fts_metadata)
|
||||
|
||||
# Prepare vector data
|
||||
metadata_text = "\n".join(
|
||||
@ -473,6 +473,14 @@ async def update_fts_and_vec(mapper, connection, target):
|
||||
]
|
||||
)
|
||||
|
||||
# Append the OCR text to metadata_text via process_ocr_result; max_length=128 caps the number of OCR items joined (not characters)
|
||||
ocr_result = next(
|
||||
(entry.value for entry in target.metadata_entries if entry.key == "ocr_result"),
|
||||
""
|
||||
)
|
||||
processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
|
||||
metadata_text += f"\nocr_result: {processed_ocr_result}"
|
||||
|
||||
# Use the new get_embeddings function
|
||||
embeddings = await get_embeddings([metadata_text])
|
||||
if not embeddings:
|
||||
@ -512,4 +520,4 @@ def update_fts_and_vec_sync(mapper, connection, target):
|
||||
# Replace the old event listener with the new sync version
|
||||
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
|
||||
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
|
||||
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
|
||||
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
|
||||
|
Loading…
x
Reference in New Issue
Block a user