refact: remove vec and fts entries when deleting an entity in crud

This commit is contained in:
arkohut 2024-11-06 17:37:10 +08:00
parent 6450d20e32
commit 22c2158a58
4 changed files with 162 additions and 15 deletions

View File

@ -184,6 +184,17 @@ def get_entities_by_filepaths(filepaths: List[str], db: Session) -> List[Entity]
def remove_entity(entity_id: int, db: Session):
entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
if entity:
# Delete the entity from FTS and vec tables first
db.execute(
text("DELETE FROM entities_fts WHERE id = :id"),
{"id": entity_id}
)
db.execute(
text("DELETE FROM entities_vec WHERE rowid = :id"),
{"id": entity_id}
)
# Then delete the entity itself
db.delete(entity)
db.commit()
else:

View File

@ -478,4 +478,3 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
# Add event listeners for EntityModel
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)

View File

@ -12,8 +12,8 @@ from enum import Enum
class FolderType(Enum):
DEFAULT = "default"
DUMMY = "dummy"
DEFAULT = "DEFAULT"
DUMMY = "DUMMY"
class MetadataSource(Enum):

View File

@ -1,9 +1,10 @@
import json
import os
import pytest
from datetime import datetime
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from pathlib import Path
@ -16,12 +17,15 @@ from memos.schemas import (
NewEntityParam,
UpdateEntityParam,
NewFoldersParam,
NewFolderParam,
EntityMetadataParam,
MetadataType,
UpdateEntityTagsParam,
UpdateEntityMetadataParam,
FolderType,
)
from memos.models import Base
from memos.models import Base, load_extension
from memos.config import settings
engine = create_engine(
@ -29,6 +33,10 @@ engine = create_engine(
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
# Add the extension-loading event listener
event.listen(engine, "connect", load_extension)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@ -47,7 +55,13 @@ def setup_library_with_entity(client):
library_id = library_response.json()["id"]
# Create a new folder in the library
new_folder = NewFoldersParam(folders=["/tmp"])
new_folder = NewFoldersParam(
folders=[
NewFolderParam(
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
)
]
)
folder_response = client.post(
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
)
@ -88,10 +102,45 @@ app.dependency_overrides[get_db] = override_get_db
# Setup a fixture for the FastAPI test client
@pytest.fixture
def client():
# Test fixture: builds the schema on `engine`, yields a TestClient bound to `app`,
# then tears the schema down again once the test using the fixture has finished.
# Create all the base tables
Base.metadata.create_all(bind=engine)
# Create the FTS and Vec tables
# They are created here with raw SQL rather than through Base.metadata.
with engine.connect() as conn:
# Create the FTS table
# NOTE(review): the 'simple' tokenizer and the vec0 module below are presumably
# provided by the extensions loaded on connect -- confirm against load_extension.
conn.execute(
text(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
id, filepath, tags, metadata,
tokenize = 'simple 0'
)
"""
)
)
# Create the Vec table
# The embedding width is taken from settings.embedding.num_dim.
conn.execute(
text(
f"""
CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
embedding float[{settings.embedding.num_dim}]
)
"""
)
)
# Commit explicitly so the virtual tables are persisted on this connection.
conn.commit()
with TestClient(app) as client:
yield client
# Clean up the database
Base.metadata.drop_all(bind=engine)
# The virtual tables were created with raw SQL above, so drop them explicitly too.
with engine.connect() as conn:
conn.execute(text("DROP TABLE IF EXISTS entities_fts"))
conn.execute(text("DROP TABLE IF EXISTS entities_vec"))
conn.commit()
# Test the new_library endpoint
@ -119,8 +168,15 @@ def test_new_library(client):
def test_list_libraries(client):
# Setup data: Create a new library
new_library = NewLibraryParam(name="Sample Library", folders=["/tmp"])
# Setup data: Create a new library with a folder
new_library = NewLibraryParam(
name="Sample Library",
folders=[
NewFolderParam(
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
)
],
)
client.post("/libraries", json=new_library.model_dump(mode="json"))
# Make a GET request to the /libraries endpoint
@ -130,20 +186,39 @@ def test_list_libraries(client):
assert response.status_code == 200
# Check the response data
response_data = response.json()
for folder in response_data[0]["folders"]:
assert "last_modified_at" in folder
assert isinstance(folder["last_modified_at"], str)
del folder["last_modified_at"]
expected_data = [
{
"id": 1,
"name": "Sample Library",
"folders": [{"id": 1, "path": "/tmp"}],
"folders": [
{
"id": 1,
"path": "/tmp",
"type": "DEFAULT",
}
],
"plugins": [],
}
]
assert response.json() == expected_data
assert response_data == expected_data
def test_new_entity(client):
# Setup data: Create a new library
new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"])
new_library = NewLibraryParam(
name="Library for Entity Test",
folders=[
NewFolderParam(
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
)
],
)
library_response = client.post(
"/libraries", json=new_library.model_dump(mode="json")
)
@ -226,7 +301,14 @@ def test_update_entity(client):
# Test for getting an entity by filepath
def test_get_entity_by_filepath(client):
# Setup data: Create a new library and entity
new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"])
new_library = NewLibraryParam(
name="Library for Get Entity Test",
folders=[
NewFolderParam(
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
)
],
)
library_response = client.post(
"/libraries", json=new_library.model_dump(mode="json")
)
@ -289,7 +371,13 @@ def test_list_entities_in_folder(client):
)
library_id = library_response.json()["id"]
new_folder = NewFoldersParam(folders=["/tmp"])
new_folder = NewFoldersParam(
folders=[
NewFolderParam(
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
)
]
)
folder_response = client.post(
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
)
@ -343,15 +431,45 @@ def test_list_entities_in_folder(client):
def test_remove_entity(client):
library_id, _, entity_id = setup_library_with_entity(client)
# Verify the entity data was automatically inserted into fts and vec tables by event listeners
with engine.connect() as conn:
fts_count = conn.execute(
text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"),
{"id": entity_id}
).scalar()
assert fts_count == 1, "Entity was not automatically added to entities_fts table"
vec_count = conn.execute(
text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"),
{"id": entity_id}
).scalar()
assert vec_count == 1, "Entity was not automatically added to entities_vec table"
# Delete the entity
delete_response = client.delete(f"/libraries/{library_id}/entities/{entity_id}")
assert delete_response.status_code == 204
# Verify the entity is deleted
# Verify the entity is deleted from the main table
get_response = client.get(f"/libraries/{library_id}/entities/{entity_id}")
assert get_response.status_code == 404
assert get_response.json() == {"detail": "Entity not found"}
# Verify the entity is deleted from entities_fts and entities_vec tables
with engine.connect() as conn:
# Check entities_fts
fts_count = conn.execute(
text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"),
{"id": entity_id}
).scalar()
assert fts_count == 0, "Entity was not deleted from entities_fts table"
# Check entities_vec
vec_count = conn.execute(
text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"),
{"id": entity_id}
).scalar()
assert vec_count == 0, "Entity was not deleted from entities_vec table"
# Test for entity not found in the specified library
invalid_delete_response = client.delete(f"/libraries/{library_id}/entities/9999")
assert invalid_delete_response.status_code == 404
@ -374,7 +492,15 @@ def test_add_folder_to_library(client):
library_id = library_response.json()["id"]
# Add a new folder to the library
new_folders = NewFoldersParam(folders=[tmp_folder_path])
new_folders = NewFoldersParam(
folders=[
NewFolderParam(
path=tmp_folder_path,
last_modified_at=datetime.now(),
type=FolderType.DEFAULT,
)
]
)
folder_response = client.post(
f"/libraries/{library_id}/folders", json=new_folders.model_dump(mode="json")
)
@ -615,6 +741,17 @@ def test_patch_entity_metadata_entries(client):
# Check the response data
patched_entity_data = patch_response.json()
expected_data = load_fixture("patch_entity_metadata_response.json")
# Check for and remove last_scan_at
assert "last_scan_at" in patched_entity_data
assert isinstance(patched_entity_data["last_scan_at"], str)
datetime.fromisoformat(patched_entity_data["last_scan_at"].replace("Z", "+00:00"))
del patched_entity_data["last_scan_at"]
if "last_scan_at" in expected_data:
del expected_data["last_scan_at"]
assert patched_entity_data == expected_data
# Update the "author" attribute of the entity