mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
refactor: remove vec and fts rows when deleting an entity in crud
This commit is contained in:
parent
6450d20e32
commit
22c2158a58
@ -184,6 +184,17 @@ def get_entities_by_filepaths(filepaths: List[str], db: Session) -> List[Entity]
|
||||
def remove_entity(entity_id: int, db: Session):
|
||||
entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
|
||||
if entity:
|
||||
# Delete the entity from FTS and vec tables first
|
||||
db.execute(
|
||||
text("DELETE FROM entities_fts WHERE id = :id"),
|
||||
{"id": entity_id}
|
||||
)
|
||||
db.execute(
|
||||
text("DELETE FROM entities_vec WHERE rowid = :id"),
|
||||
{"id": entity_id}
|
||||
)
|
||||
|
||||
# Then delete the entity itself
|
||||
db.delete(entity)
|
||||
db.commit()
|
||||
else:
|
||||
|
@ -478,4 +478,3 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
|
||||
# Add event listeners for EntityModel
|
||||
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
|
||||
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
|
||||
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
|
@ -12,8 +12,8 @@ from enum import Enum
|
||||
|
||||
|
||||
class FolderType(Enum):
    """Kind of folder attached to a library.

    Each member is defined exactly once; assigning the same member name
    twice inside an ``Enum`` body raises ``TypeError`` when the class is
    created. The values are the upper-case strings that show up in
    serialized output (for example ``"type": "DEFAULT"``).
    """

    DEFAULT = "DEFAULT"
    DUMMY = "DUMMY"
|
||||
|
||||
class MetadataSource(Enum):
|
||||
|
@ -1,9 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from pathlib import Path
|
||||
@ -16,12 +17,15 @@ from memos.schemas import (
|
||||
NewEntityParam,
|
||||
UpdateEntityParam,
|
||||
NewFoldersParam,
|
||||
NewFolderParam,
|
||||
EntityMetadataParam,
|
||||
MetadataType,
|
||||
UpdateEntityTagsParam,
|
||||
UpdateEntityMetadataParam,
|
||||
FolderType,
|
||||
)
|
||||
from memos.models import Base
|
||||
from memos.models import Base, load_extension
|
||||
from memos.config import settings
|
||||
|
||||
|
||||
engine = create_engine(
|
||||
@ -29,6 +33,10 @@ engine = create_engine(
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
# Add the event listener that loads the extension
|
||||
event.listen(engine, "connect", load_extension)
|
||||
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@ -47,7 +55,13 @@ def setup_library_with_entity(client):
|
||||
library_id = library_response.json()["id"]
|
||||
|
||||
# Create a new folder in the library
|
||||
new_folder = NewFoldersParam(folders=["/tmp"])
|
||||
new_folder = NewFoldersParam(
|
||||
folders=[
|
||||
NewFolderParam(
|
||||
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||
)
|
||||
]
|
||||
)
|
||||
folder_response = client.post(
|
||||
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
|
||||
)
|
||||
@ -88,10 +102,45 @@ app.dependency_overrides[get_db] = override_get_db
|
||||
# Setup a fixture for the FastAPI test client
@pytest.fixture
def client():
    """Yield a ``TestClient`` backed by a freshly created database schema.

    Setup creates every table registered on ``Base.metadata`` and then the
    ``entities_fts`` (fts5) and ``entities_vec`` (vec0) virtual tables.
    The code after the ``yield`` drops all of them again.
    """
    create_fts_table = text(
        """
        CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
            id, filepath, tags, metadata,
            tokenize = 'simple 0'
        )
        """
    )
    # The embedding width is taken from the application settings.
    create_vec_table = text(
        f"""
        CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
            embedding float[{settings.embedding.num_dim}]
        )
        """
    )

    # Create all of the basic tables
    Base.metadata.create_all(bind=engine)

    # Create the FTS and vec tables
    with engine.connect() as conn:
        conn.execute(create_fts_table)
        conn.execute(create_vec_table)
        conn.commit()

    with TestClient(app) as test_client:
        yield test_client

    # Clean up the database
    Base.metadata.drop_all(bind=engine)
    with engine.connect() as conn:
        for drop_statement in (
            "DROP TABLE IF EXISTS entities_fts",
            "DROP TABLE IF EXISTS entities_vec",
        ):
            conn.execute(text(drop_statement))
        conn.commit()
|
||||
|
||||
# Test the new_library endpoint
|
||||
@ -119,8 +168,15 @@ def test_new_library(client):
|
||||
|
||||
|
||||
def test_list_libraries(client):
|
||||
# Setup data: Create a new library
|
||||
new_library = NewLibraryParam(name="Sample Library", folders=["/tmp"])
|
||||
# Setup data: Create a new library with a folder
|
||||
new_library = NewLibraryParam(
|
||||
name="Sample Library",
|
||||
folders=[
|
||||
NewFolderParam(
|
||||
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||
)
|
||||
],
|
||||
)
|
||||
client.post("/libraries", json=new_library.model_dump(mode="json"))
|
||||
|
||||
# Make a GET request to the /libraries endpoint
|
||||
@ -130,20 +186,39 @@ def test_list_libraries(client):
|
||||
assert response.status_code == 200
|
||||
|
||||
# Check the response data
|
||||
response_data = response.json()
|
||||
for folder in response_data[0]["folders"]:
|
||||
assert "last_modified_at" in folder
|
||||
assert isinstance(folder["last_modified_at"], str)
|
||||
del folder["last_modified_at"]
|
||||
|
||||
expected_data = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Sample Library",
|
||||
"folders": [{"id": 1, "path": "/tmp"}],
|
||||
"folders": [
|
||||
{
|
||||
"id": 1,
|
||||
"path": "/tmp",
|
||||
"type": "DEFAULT",
|
||||
}
|
||||
],
|
||||
"plugins": [],
|
||||
}
|
||||
]
|
||||
assert response.json() == expected_data
|
||||
assert response_data == expected_data
|
||||
|
||||
|
||||
def test_new_entity(client):
|
||||
# Setup data: Create a new library
|
||||
new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"])
|
||||
new_library = NewLibraryParam(
|
||||
name="Library for Entity Test",
|
||||
folders=[
|
||||
NewFolderParam(
|
||||
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||
)
|
||||
],
|
||||
)
|
||||
library_response = client.post(
|
||||
"/libraries", json=new_library.model_dump(mode="json")
|
||||
)
|
||||
@ -226,7 +301,14 @@ def test_update_entity(client):
|
||||
# Test for getting an entity by filepath
|
||||
def test_get_entity_by_filepath(client):
|
||||
# Setup data: Create a new library and entity
|
||||
new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"])
|
||||
new_library = NewLibraryParam(
|
||||
name="Library for Get Entity Test",
|
||||
folders=[
|
||||
NewFolderParam(
|
||||
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||
)
|
||||
],
|
||||
)
|
||||
library_response = client.post(
|
||||
"/libraries", json=new_library.model_dump(mode="json")
|
||||
)
|
||||
@ -289,7 +371,13 @@ def test_list_entities_in_folder(client):
|
||||
)
|
||||
library_id = library_response.json()["id"]
|
||||
|
||||
new_folder = NewFoldersParam(folders=["/tmp"])
|
||||
new_folder = NewFoldersParam(
|
||||
folders=[
|
||||
NewFolderParam(
|
||||
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||
)
|
||||
]
|
||||
)
|
||||
folder_response = client.post(
|
||||
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
|
||||
)
|
||||
@ -343,15 +431,45 @@ def test_list_entities_in_folder(client):
|
||||
def test_remove_entity(client):
|
||||
library_id, _, entity_id = setup_library_with_entity(client)
|
||||
|
||||
# Verify the entity data was automatically inserted into fts and vec tables by event listeners
|
||||
with engine.connect() as conn:
|
||||
fts_count = conn.execute(
|
||||
text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"),
|
||||
{"id": entity_id}
|
||||
).scalar()
|
||||
assert fts_count == 1, "Entity was not automatically added to entities_fts table"
|
||||
|
||||
vec_count = conn.execute(
|
||||
text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"),
|
||||
{"id": entity_id}
|
||||
).scalar()
|
||||
assert vec_count == 1, "Entity was not automatically added to entities_vec table"
|
||||
|
||||
# Delete the entity
|
||||
delete_response = client.delete(f"/libraries/{library_id}/entities/{entity_id}")
|
||||
assert delete_response.status_code == 204
|
||||
|
||||
# Verify the entity is deleted
|
||||
# Verify the entity is deleted from the main table
|
||||
get_response = client.get(f"/libraries/{library_id}/entities/{entity_id}")
|
||||
assert get_response.status_code == 404
|
||||
assert get_response.json() == {"detail": "Entity not found"}
|
||||
|
||||
# Verify the entity is deleted from entities_fts and entities_vec tables
|
||||
with engine.connect() as conn:
|
||||
# Check entities_fts
|
||||
fts_count = conn.execute(
|
||||
text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"),
|
||||
{"id": entity_id}
|
||||
).scalar()
|
||||
assert fts_count == 0, "Entity was not deleted from entities_fts table"
|
||||
|
||||
# Check entities_vec
|
||||
vec_count = conn.execute(
|
||||
text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"),
|
||||
{"id": entity_id}
|
||||
).scalar()
|
||||
assert vec_count == 0, "Entity was not deleted from entities_vec table"
|
||||
|
||||
# Test for entity not found in the specified library
|
||||
invalid_delete_response = client.delete(f"/libraries/{library_id}/entities/9999")
|
||||
assert invalid_delete_response.status_code == 404
|
||||
@ -374,7 +492,15 @@ def test_add_folder_to_library(client):
|
||||
library_id = library_response.json()["id"]
|
||||
|
||||
# Add a new folder to the library
|
||||
new_folders = NewFoldersParam(folders=[tmp_folder_path])
|
||||
new_folders = NewFoldersParam(
|
||||
folders=[
|
||||
NewFolderParam(
|
||||
path=tmp_folder_path,
|
||||
last_modified_at=datetime.now(),
|
||||
type=FolderType.DEFAULT,
|
||||
)
|
||||
]
|
||||
)
|
||||
folder_response = client.post(
|
||||
f"/libraries/{library_id}/folders", json=new_folders.model_dump(mode="json")
|
||||
)
|
||||
@ -615,6 +741,17 @@ def test_patch_entity_metadata_entries(client):
|
||||
# Check the response data
|
||||
patched_entity_data = patch_response.json()
|
||||
expected_data = load_fixture("patch_entity_metadata_response.json")
|
||||
|
||||
# Check for last_scan_at and remove it
|
||||
assert "last_scan_at" in patched_entity_data
|
||||
assert isinstance(patched_entity_data["last_scan_at"], str)
|
||||
|
||||
datetime.fromisoformat(patched_entity_data["last_scan_at"].replace("Z", "+00:00"))
|
||||
|
||||
del patched_entity_data["last_scan_at"]
|
||||
if "last_scan_at" in expected_data:
|
||||
del expected_data["last_scan_at"]
|
||||
|
||||
assert patched_entity_data == expected_data
|
||||
|
||||
# Update the "author" attribute of the entity
|
||||
|
Loading…
x
Reference in New Issue
Block a user