From 22c2158a58dc24cf0b0d7175a067c6c3a4adb717 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:37:10 +0800 Subject: [PATCH] refact: remove vec and fts when delete entity in crud --- memos/crud.py | 11 +++ memos/models.py | 1 - memos/schemas.py | 4 +- memos/test_server.py | 161 +++++++++++++++++++++++++++++++++++++++---- 4 files changed, 162 insertions(+), 15 deletions(-) diff --git a/memos/crud.py b/memos/crud.py index 0b8186c..9e35954 100644 --- a/memos/crud.py +++ b/memos/crud.py @@ -184,6 +184,17 @@ def get_entities_by_filepaths(filepaths: List[str], db: Session) -> List[Entity] def remove_entity(entity_id: int, db: Session): entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first() if entity: + # Delete the entity from FTS and vec tables first + db.execute( + text("DELETE FROM entities_fts WHERE id = :id"), + {"id": entity_id} + ) + db.execute( + text("DELETE FROM entities_vec WHERE rowid = :id"), + {"id": entity_id} + ) + + # Then delete the entity itself db.delete(entity) db.commit() else: diff --git a/memos/models.py b/memos/models.py index e5b15e5..0205195 100644 --- a/memos/models.py +++ b/memos/models.py @@ -478,4 +478,3 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel): # Add event listeners for EntityModel event.listen(EntityModel, "after_insert", update_fts_and_vec_sync) event.listen(EntityModel, "after_update", update_fts_and_vec_sync) -event.listen(EntityModel, "after_delete", delete_fts_and_vec) \ No newline at end of file diff --git a/memos/schemas.py b/memos/schemas.py index 069a1cf..d544d0f 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -12,8 +12,8 @@ from enum import Enum class FolderType(Enum): - DEFAULT = "default" - DUMMY = "dummy" + DEFAULT = "DEFAULT" + DUMMY = "DUMMY" class MetadataSource(Enum): diff --git a/memos/test_server.py b/memos/test_server.py index 1afbe8e..4c9c9b3 100644 --- a/memos/test_server.py +++ b/memos/test_server.py @@ -1,9 +1,10 @@ import json import os import pytest +from datetime import datetime from fastapi.testclient import TestClient -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event, text from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool from pathlib import Path @@ -16,12 +17,15 @@ from memos.schemas import ( NewEntityParam, UpdateEntityParam, NewFoldersParam, + NewFolderParam, EntityMetadataParam, MetadataType, UpdateEntityTagsParam, UpdateEntityMetadataParam, + FolderType, ) -from memos.models import Base +from memos.models import Base, load_extension +from memos.config import settings engine = create_engine( @@ -29,6 +33,10 @@ engine = create_engine( connect_args={"check_same_thread": False}, poolclass=StaticPool, ) + +# 添加扩展加载事件监听器 +event.listen(engine, "connect", load_extension) + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -47,7 +55,13 @@ def setup_library_with_entity(client): library_id = library_response.json()["id"] # Create a new folder in the library - new_folder = NewFoldersParam(folders=["/tmp"]) + new_folder = NewFoldersParam( + folders=[ + NewFolderParam( + path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT + ) + ] + ) folder_response = client.post( f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") ) @@ -88,10 +102,45 @@ app.dependency_overrides[get_db] = override_get_db # Setup a fixture for the FastAPI test client @pytest.fixture def client(): + # 创建所有基本表 Base.metadata.create_all(bind=engine) + + # 创建 FTS 和 Vec 表 + with engine.connect() as conn: + # 创建 FTS 表 + conn.execute( + text( + """ + CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5( + id, filepath, tags, metadata, + tokenize = 'simple 0' + ) + """ + ) + ) + + # 创建 Vec 表 + conn.execute( + text( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0( + embedding float[{settings.embedding.num_dim}] + ) + """ + ) + ) + + conn.commit() + with TestClient(app) as client: yield client + + # 清理数据库 Base.metadata.drop_all(bind=engine) + with engine.connect() as conn: + conn.execute(text("DROP TABLE IF EXISTS entities_fts")) + conn.execute(text("DROP TABLE IF EXISTS entities_vec")) + conn.commit() # Test the new_library endpoint @@ -119,8 +168,15 @@ def test_new_library(client): def test_list_libraries(client): - # Setup data: Create a new library - new_library = NewLibraryParam(name="Sample Library", folders=["/tmp"]) + # Setup data: Create a new library with a folder + new_library = NewLibraryParam( + name="Sample Library", + folders=[ + NewFolderParam( + path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT + ) + ], + ) client.post("/libraries", json=new_library.model_dump(mode="json")) # Make a GET request to the /libraries endpoint @@ -130,20 +186,39 @@ def test_list_libraries(client): assert response.status_code == 200 # Check the response data + response_data = response.json() + for folder in response_data[0]["folders"]: + assert "last_modified_at" in folder + assert isinstance(folder["last_modified_at"], str) + del folder["last_modified_at"] + expected_data = [ { "id": 1, "name": "Sample Library", - "folders": [{"id": 1, "path": "/tmp"}], + "folders": [ + { + "id": 1, + "path": "/tmp", + "type": "DEFAULT", + } + ], "plugins": [], } ] - assert response.json() == expected_data + assert response_data == expected_data def test_new_entity(client): # Setup data: Create a new library - new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"]) + new_library = NewLibraryParam( + name="Library for Entity Test", + folders=[ + NewFolderParam( + path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT + ) + ], + ) library_response = client.post( "/libraries", json=new_library.model_dump(mode="json") ) @@ -226,7 +301,14 @@ def test_update_entity(client): # Test for getting an entity by filepath def test_get_entity_by_filepath(client): # Setup data: Create a new library and entity - new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"]) + new_library = NewLibraryParam( + name="Library for Get Entity Test", + folders=[ + NewFolderParam( + path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT + ) + ], + ) library_response = client.post( "/libraries", json=new_library.model_dump(mode="json") ) @@ -289,7 +371,13 @@ def test_list_entities_in_folder(client): ) library_id = library_response.json()["id"] - new_folder = NewFoldersParam(folders=["/tmp"]) + new_folder = NewFoldersParam( + folders=[ + NewFolderParam( + path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT + ) + ] + ) folder_response = client.post( f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json") ) @@ -343,15 +431,45 @@ def test_list_entities_in_folder(client): def test_remove_entity(client): library_id, _, entity_id = setup_library_with_entity(client) + # Verify the entity data was automatically inserted into fts and vec tables by event listeners + with engine.connect() as conn: + fts_count = conn.execute( + text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"), + {"id": entity_id} + ).scalar() + assert fts_count == 1, "Entity was not automatically added to entities_fts table" + + vec_count = conn.execute( + text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"), + {"id": entity_id} + ).scalar() + assert vec_count == 1, "Entity was not automatically added to entities_vec table" + # Delete the entity delete_response = client.delete(f"/libraries/{library_id}/entities/{entity_id}") assert delete_response.status_code == 204 - # Verify the entity is deleted + # Verify the entity is deleted from the main table get_response = client.get(f"/libraries/{library_id}/entities/{entity_id}") assert get_response.status_code == 404 assert get_response.json() == {"detail": "Entity not found"} + # Verify the entity is deleted from entities_fts and entities_vec tables + with engine.connect() as conn: + # Check entities_fts + fts_count = conn.execute( + text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"), + {"id": entity_id} + ).scalar() + assert fts_count == 0, "Entity was not deleted from entities_fts table" + + # Check entities_vec + vec_count = conn.execute( + text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"), + {"id": entity_id} + ).scalar() + assert vec_count == 0, "Entity was not deleted from entities_vec table" + # Test for entity not found in the specified library invalid_delete_response = client.delete(f"/libraries/{library_id}/entities/9999") assert invalid_delete_response.status_code == 404 @@ -374,7 +492,15 @@ def test_add_folder_to_library(client): library_id = library_response.json()["id"] # Add a new folder to the library - new_folders = NewFoldersParam(folders=[tmp_folder_path]) + new_folders = NewFoldersParam( + folders=[ + NewFolderParam( + path=tmp_folder_path, + last_modified_at=datetime.now(), + type=FolderType.DEFAULT, + ) + ] + ) folder_response = client.post( f"/libraries/{library_id}/folders", json=new_folders.model_dump(mode="json") ) @@ -615,6 +741,17 @@ def test_patch_entity_metadata_entries(client): # Check the response data patched_entity_data = patch_response.json() expected_data = load_fixture("patch_entity_metadata_response.json") + + # 检查并移除 last_scan_at + assert "last_scan_at" in patched_entity_data + assert isinstance(patched_entity_data["last_scan_at"], str) + + datetime.fromisoformat(patched_entity_data["last_scan_at"].replace("Z", "+00:00")) + + del patched_entity_data["last_scan_at"] + if "last_scan_at" in expected_data: + del expected_data["last_scan_at"] + assert patched_entity_data == expected_data # Update the "author" attribute of the entity