From a1a5e576c098260101780036f709b59f457a4429 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 21 Oct 2024 14:49:30 +0800
Subject: [PATCH] feat(model): add ocr to embedding table

---
 memos/models.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/memos/models.py b/memos/models.py
index 30f5f3e..3a9b951 100644
--- a/memos/models.py
+++ b/memos/models.py
@@ -438,7 +438,7 @@ async def update_fts_and_vec(mapper, connection, target):
     tags = ", ".join([tag.name for tag in target.tags])
 
     # Process metadata entries
-    def process_ocr_result(value):
+    def process_ocr_result(value, max_length=4096):
         try:
             ocr_data = json.loads(value)
             if isinstance(ocr_data, list) and all(
@@ -448,13 +448,13 @@ async def update_fts_and_vec(mapper, connection, target):
                 and "score" in item
                 for item in ocr_data
             ):
-                return " ".join(item["rec_txt"] for item in ocr_data)
+                return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
             else:
                 return json.dumps(ocr_data, indent=2)
         except json.JSONDecodeError:
             return value
 
-    metadata = "\n".join(
+    fts_metadata = "\n".join(
         [
             f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
             for entry in target.metadata_entries
@@ -462,7 +462,7 @@ async def update_fts_and_vec(mapper, connection, target):
     )
 
     # Update FTS table
-    update_or_insert_entities_fts(session, target.id, target.filepath, tags, metadata)
+    update_or_insert_entities_fts(session, target.id, target.filepath, tags, fts_metadata)
 
     # Prepare vector data
     metadata_text = "\n".join(
@@ -473,6 +473,14 @@ async def update_fts_and_vec(mapper, connection, target):
         ]
     )
 
+    # Add ocr_result at the end of metadata_text using process_ocr_result
+    ocr_result = next(
+        (entry.value for entry in target.metadata_entries if entry.key == "ocr_result"),
+        ""
+    )
+    processed_ocr_result = process_ocr_result(ocr_result, max_length=128)
+    metadata_text += f"\nocr_result: {processed_ocr_result}"
+
     # Use the new get_embeddings function
     embeddings = await get_embeddings([metadata_text])
     if not embeddings:
@@ -512,4 +520,4 @@ def update_fts_and_vec_sync(mapper, connection, target):
 # Replace the old event listener with the new sync version
 event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
 event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
-event.listen(EntityModel, "after_delete", delete_fts_and_vec)
\ No newline at end of file
+event.listen(EntityModel, "after_delete", delete_fts_and_vec)