diff --git a/memos/models.py b/memos/models.py index 0c39d27..51782d9 100644 --- a/memos/models.py +++ b/memos/models.py @@ -326,10 +326,12 @@ def init_default_libraries(session, default_plugins): bind_response = session.query(PluginModel).filter_by(name=plugin.name).first() if bind_response: # Check if the LibraryPluginModel already exists - existing_library_plugin = session.query(LibraryPluginModel).filter_by( - library_id=1, plugin_id=bind_response.id - ).first() - + existing_library_plugin = ( + session.query(LibraryPluginModel) + .filter_by(library_id=1, plugin_id=bind_response.id) + .first() + ) + if not existing_library_plugin: library_plugin = LibraryPluginModel( library_id=1, plugin_id=bind_response.id @@ -341,27 +343,23 @@ def init_default_libraries(session, default_plugins): async def update_or_insert_entities_vec(session, target_id, embedding): try: - # First, try to update the existing row - result = session.execute( - text("UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"), + session.execute( + text("DELETE FROM entities_vec WHERE rowid = :id"), + {"id": target_id} + ) + + session.execute( + text( + """ + INSERT INTO entities_vec (rowid, embedding) + VALUES (:id, :embedding) + """ + ), { "id": target_id, "embedding": serialize_float32(embedding), }, ) - - # If no row was updated (i.e., the row doesn't exist), then insert a new row - if result.rowcount == 0: - session.execute( - text( - "INSERT INTO entities_vec (rowid, embedding) VALUES (:id, :embedding)" - ), - { - "id": target_id, - "embedding": serialize_float32(embedding), - }, - ) - session.commit() except Exception as e: print(f"Error updating entities_vec: {e}") @@ -370,13 +368,11 @@ async def update_or_insert_entities_vec(session, target_id, embedding): def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata): try: - # First, try to update the existing row - result = session.execute( + session.execute( text( """ - UPDATE entities_fts - SET filepath = :filepath, tags = :tags, metadata = :metadata - WHERE id = :id + INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata) + VALUES(:id, :filepath, :tags, :metadata) """ ), { @@ -386,35 +382,17 @@ def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata): "metadata": metadata, }, ) - - # If no row was updated (i.e., the row doesn't exist), then insert a new row - if result.rowcount == 0: - session.execute( - text( - """ - INSERT INTO entities_fts(id, filepath, tags, metadata) - VALUES(:id, :filepath, :tags, :metadata) - """ - ), - { - "id": target_id, - "filepath": filepath, - "tags": tags, - "metadata": metadata, - }, - ) - session.commit() except Exception as e: print(f"Error updating entities_fts: {e}") session.rollback() -async def update_fts_and_vec(mapper, connection, target): +async def update_fts_and_vec(mapper, connection, entity: EntityModel): session = Session(bind=connection) # Prepare FTS data - tags = ", ".join([tag.name for tag in target.tags]) + tags = ", ".join([tag.name for tag in entity.tags]) # Process metadata entries def process_ocr_result(value, max_length=4096): @@ -436,26 +414,28 @@ async def update_fts_and_vec(mapper, connection, target): fts_metadata = "\n".join( [ f"{entry.key}: {process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}" - for entry in target.metadata_entries + for entry in entity.metadata_entries ] ) # Update FTS table - update_or_insert_entities_fts(session, target.id, target.filepath, tags, fts_metadata) + update_or_insert_entities_fts( + session, entity.id, entity.filepath, tags, fts_metadata + ) # Prepare vector data metadata_text = "\n".join( [ f"{entry.key}: {entry.value}" - for entry in target.metadata_entries + for entry in entity.metadata_entries if entry.key != "ocr_result" ] ) # Add ocr_result at the end of metadata_text using process_ocr_result ocr_result = next( - (entry.value for entry in target.metadata_entries if entry.key == "ocr_result"), - "" + (entry.value for entry in entity.metadata_entries if entry.key == "ocr_result"), + "", ) processed_ocr_result = process_ocr_result(ocr_result, max_length=128) metadata_text += f"\nocr_result: {processed_ocr_result}" @@ -469,15 +449,15 @@ async def update_fts_and_vec(mapper, connection, target): # Update vector table if embedding: - await update_or_insert_entities_vec(session, target.id, embedding) + await update_or_insert_entities_vec(session, entity.id, embedding) -def delete_fts_and_vec(mapper, connection, target): +def delete_fts_and_vec(mapper, connection, entity: EntityModel): connection.execute( - text("DELETE FROM entities_fts WHERE id = :id"), {"id": target.id} + text("DELETE FROM entities_fts WHERE id = :id"), {"id": entity.id} ) connection.execute( - text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": target.id} + text("DELETE FROM entities_vec WHERE rowid = :id"), {"id": entity.id} ) @@ -487,9 +467,9 @@ def run_async(coro): return loop.run_until_complete(coro) -def update_fts_and_vec_sync(mapper, connection, target): +def update_fts_and_vec_sync(mapper, connection, entity: EntityModel): def run_in_thread(): - run_async(update_fts_and_vec(mapper, connection, target)) + run_async(update_fts_and_vec(mapper, connection, entity)) thread = threading.Thread(target=run_in_thread) thread.start()