mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 19:25:24 +00:00
refact: remove vec and fts when delete entity in crud
This commit is contained in:
parent
6450d20e32
commit
22c2158a58
@ -184,6 +184,17 @@ def get_entities_by_filepaths(filepaths: List[str], db: Session) -> List[Entity]
|
|||||||
def remove_entity(entity_id: int, db: Session):
|
def remove_entity(entity_id: int, db: Session):
|
||||||
entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
|
entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
|
||||||
if entity:
|
if entity:
|
||||||
|
# Delete the entity from FTS and vec tables first
|
||||||
|
db.execute(
|
||||||
|
text("DELETE FROM entities_fts WHERE id = :id"),
|
||||||
|
{"id": entity_id}
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
text("DELETE FROM entities_vec WHERE rowid = :id"),
|
||||||
|
{"id": entity_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then delete the entity itself
|
||||||
db.delete(entity)
|
db.delete(entity)
|
||||||
db.commit()
|
db.commit()
|
||||||
else:
|
else:
|
||||||
|
@ -478,4 +478,3 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
|
|||||||
# Add event listeners for EntityModel
|
# Add event listeners for EntityModel
|
||||||
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
|
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
|
||||||
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
|
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
|
||||||
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
|
|
@ -12,8 +12,8 @@ from enum import Enum
|
|||||||
|
|
||||||
|
|
||||||
class FolderType(Enum):
|
class FolderType(Enum):
|
||||||
DEFAULT = "default"
|
DEFAULT = "DEFAULT"
|
||||||
DUMMY = "dummy"
|
DUMMY = "DUMMY"
|
||||||
|
|
||||||
|
|
||||||
class MetadataSource(Enum):
|
class MetadataSource(Enum):
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event, text
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -16,12 +17,15 @@ from memos.schemas import (
|
|||||||
NewEntityParam,
|
NewEntityParam,
|
||||||
UpdateEntityParam,
|
UpdateEntityParam,
|
||||||
NewFoldersParam,
|
NewFoldersParam,
|
||||||
|
NewFolderParam,
|
||||||
EntityMetadataParam,
|
EntityMetadataParam,
|
||||||
MetadataType,
|
MetadataType,
|
||||||
UpdateEntityTagsParam,
|
UpdateEntityTagsParam,
|
||||||
UpdateEntityMetadataParam,
|
UpdateEntityMetadataParam,
|
||||||
|
FolderType,
|
||||||
)
|
)
|
||||||
from memos.models import Base
|
from memos.models import Base, load_extension
|
||||||
|
from memos.config import settings
|
||||||
|
|
||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
@ -29,6 +33,10 @@ engine = create_engine(
|
|||||||
connect_args={"check_same_thread": False},
|
connect_args={"check_same_thread": False},
|
||||||
poolclass=StaticPool,
|
poolclass=StaticPool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 添加扩展加载事件监听器
|
||||||
|
event.listen(engine, "connect", load_extension)
|
||||||
|
|
||||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +55,13 @@ def setup_library_with_entity(client):
|
|||||||
library_id = library_response.json()["id"]
|
library_id = library_response.json()["id"]
|
||||||
|
|
||||||
# Create a new folder in the library
|
# Create a new folder in the library
|
||||||
new_folder = NewFoldersParam(folders=["/tmp"])
|
new_folder = NewFoldersParam(
|
||||||
|
folders=[
|
||||||
|
NewFolderParam(
|
||||||
|
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
folder_response = client.post(
|
folder_response = client.post(
|
||||||
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
|
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
@ -88,10 +102,45 @@ app.dependency_overrides[get_db] = override_get_db
|
|||||||
# Setup a fixture for the FastAPI test client
|
# Setup a fixture for the FastAPI test client
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
|
# 创建所有基本表
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
# 创建 FTS 和 Vec 表
|
||||||
|
with engine.connect() as conn:
|
||||||
|
# 创建 FTS 表
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
|
||||||
|
id, filepath, tags, metadata,
|
||||||
|
tokenize = 'simple 0'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 Vec 表
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
|
||||||
|
embedding float[{settings.embedding.num_dim}]
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
# 清理数据库
|
||||||
Base.metadata.drop_all(bind=engine)
|
Base.metadata.drop_all(bind=engine)
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(text("DROP TABLE IF EXISTS entities_fts"))
|
||||||
|
conn.execute(text("DROP TABLE IF EXISTS entities_vec"))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
# Test the new_library endpoint
|
# Test the new_library endpoint
|
||||||
@ -119,8 +168,15 @@ def test_new_library(client):
|
|||||||
|
|
||||||
|
|
||||||
def test_list_libraries(client):
|
def test_list_libraries(client):
|
||||||
# Setup data: Create a new library
|
# Setup data: Create a new library with a folder
|
||||||
new_library = NewLibraryParam(name="Sample Library", folders=["/tmp"])
|
new_library = NewLibraryParam(
|
||||||
|
name="Sample Library",
|
||||||
|
folders=[
|
||||||
|
NewFolderParam(
|
||||||
|
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
client.post("/libraries", json=new_library.model_dump(mode="json"))
|
client.post("/libraries", json=new_library.model_dump(mode="json"))
|
||||||
|
|
||||||
# Make a GET request to the /libraries endpoint
|
# Make a GET request to the /libraries endpoint
|
||||||
@ -130,20 +186,39 @@ def test_list_libraries(client):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# Check the response data
|
# Check the response data
|
||||||
|
response_data = response.json()
|
||||||
|
for folder in response_data[0]["folders"]:
|
||||||
|
assert "last_modified_at" in folder
|
||||||
|
assert isinstance(folder["last_modified_at"], str)
|
||||||
|
del folder["last_modified_at"]
|
||||||
|
|
||||||
expected_data = [
|
expected_data = [
|
||||||
{
|
{
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"name": "Sample Library",
|
"name": "Sample Library",
|
||||||
"folders": [{"id": 1, "path": "/tmp"}],
|
"folders": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"path": "/tmp",
|
||||||
|
"type": "DEFAULT",
|
||||||
|
}
|
||||||
|
],
|
||||||
"plugins": [],
|
"plugins": [],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
assert response.json() == expected_data
|
assert response_data == expected_data
|
||||||
|
|
||||||
|
|
||||||
def test_new_entity(client):
|
def test_new_entity(client):
|
||||||
# Setup data: Create a new library
|
# Setup data: Create a new library
|
||||||
new_library = NewLibraryParam(name="Library for Entity Test", folders=["/tmp"])
|
new_library = NewLibraryParam(
|
||||||
|
name="Library for Entity Test",
|
||||||
|
folders=[
|
||||||
|
NewFolderParam(
|
||||||
|
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
library_response = client.post(
|
library_response = client.post(
|
||||||
"/libraries", json=new_library.model_dump(mode="json")
|
"/libraries", json=new_library.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
@ -226,7 +301,14 @@ def test_update_entity(client):
|
|||||||
# Test for getting an entity by filepath
|
# Test for getting an entity by filepath
|
||||||
def test_get_entity_by_filepath(client):
|
def test_get_entity_by_filepath(client):
|
||||||
# Setup data: Create a new library and entity
|
# Setup data: Create a new library and entity
|
||||||
new_library = NewLibraryParam(name="Library for Get Entity Test", folders=["/tmp"])
|
new_library = NewLibraryParam(
|
||||||
|
name="Library for Get Entity Test",
|
||||||
|
folders=[
|
||||||
|
NewFolderParam(
|
||||||
|
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
library_response = client.post(
|
library_response = client.post(
|
||||||
"/libraries", json=new_library.model_dump(mode="json")
|
"/libraries", json=new_library.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
@ -289,7 +371,13 @@ def test_list_entities_in_folder(client):
|
|||||||
)
|
)
|
||||||
library_id = library_response.json()["id"]
|
library_id = library_response.json()["id"]
|
||||||
|
|
||||||
new_folder = NewFoldersParam(folders=["/tmp"])
|
new_folder = NewFoldersParam(
|
||||||
|
folders=[
|
||||||
|
NewFolderParam(
|
||||||
|
path="/tmp", last_modified_at=datetime.now(), type=FolderType.DEFAULT
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
folder_response = client.post(
|
folder_response = client.post(
|
||||||
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
|
f"/libraries/{library_id}/folders", json=new_folder.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
@ -343,15 +431,45 @@ def test_list_entities_in_folder(client):
|
|||||||
def test_remove_entity(client):
|
def test_remove_entity(client):
|
||||||
library_id, _, entity_id = setup_library_with_entity(client)
|
library_id, _, entity_id = setup_library_with_entity(client)
|
||||||
|
|
||||||
|
# Verify the entity data was automatically inserted into fts and vec tables by event listeners
|
||||||
|
with engine.connect() as conn:
|
||||||
|
fts_count = conn.execute(
|
||||||
|
text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"),
|
||||||
|
{"id": entity_id}
|
||||||
|
).scalar()
|
||||||
|
assert fts_count == 1, "Entity was not automatically added to entities_fts table"
|
||||||
|
|
||||||
|
vec_count = conn.execute(
|
||||||
|
text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"),
|
||||||
|
{"id": entity_id}
|
||||||
|
).scalar()
|
||||||
|
assert vec_count == 1, "Entity was not automatically added to entities_vec table"
|
||||||
|
|
||||||
# Delete the entity
|
# Delete the entity
|
||||||
delete_response = client.delete(f"/libraries/{library_id}/entities/{entity_id}")
|
delete_response = client.delete(f"/libraries/{library_id}/entities/{entity_id}")
|
||||||
assert delete_response.status_code == 204
|
assert delete_response.status_code == 204
|
||||||
|
|
||||||
# Verify the entity is deleted
|
# Verify the entity is deleted from the main table
|
||||||
get_response = client.get(f"/libraries/{library_id}/entities/{entity_id}")
|
get_response = client.get(f"/libraries/{library_id}/entities/{entity_id}")
|
||||||
assert get_response.status_code == 404
|
assert get_response.status_code == 404
|
||||||
assert get_response.json() == {"detail": "Entity not found"}
|
assert get_response.json() == {"detail": "Entity not found"}
|
||||||
|
|
||||||
|
# Verify the entity is deleted from entities_fts and entities_vec tables
|
||||||
|
with engine.connect() as conn:
|
||||||
|
# Check entities_fts
|
||||||
|
fts_count = conn.execute(
|
||||||
|
text("SELECT COUNT(*) FROM entities_fts WHERE id = :id"),
|
||||||
|
{"id": entity_id}
|
||||||
|
).scalar()
|
||||||
|
assert fts_count == 0, "Entity was not deleted from entities_fts table"
|
||||||
|
|
||||||
|
# Check entities_vec
|
||||||
|
vec_count = conn.execute(
|
||||||
|
text("SELECT COUNT(*) FROM entities_vec WHERE rowid = :id"),
|
||||||
|
{"id": entity_id}
|
||||||
|
).scalar()
|
||||||
|
assert vec_count == 0, "Entity was not deleted from entities_vec table"
|
||||||
|
|
||||||
# Test for entity not found in the specified library
|
# Test for entity not found in the specified library
|
||||||
invalid_delete_response = client.delete(f"/libraries/{library_id}/entities/9999")
|
invalid_delete_response = client.delete(f"/libraries/{library_id}/entities/9999")
|
||||||
assert invalid_delete_response.status_code == 404
|
assert invalid_delete_response.status_code == 404
|
||||||
@ -374,7 +492,15 @@ def test_add_folder_to_library(client):
|
|||||||
library_id = library_response.json()["id"]
|
library_id = library_response.json()["id"]
|
||||||
|
|
||||||
# Add a new folder to the library
|
# Add a new folder to the library
|
||||||
new_folders = NewFoldersParam(folders=[tmp_folder_path])
|
new_folders = NewFoldersParam(
|
||||||
|
folders=[
|
||||||
|
NewFolderParam(
|
||||||
|
path=tmp_folder_path,
|
||||||
|
last_modified_at=datetime.now(),
|
||||||
|
type=FolderType.DEFAULT,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
folder_response = client.post(
|
folder_response = client.post(
|
||||||
f"/libraries/{library_id}/folders", json=new_folders.model_dump(mode="json")
|
f"/libraries/{library_id}/folders", json=new_folders.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
@ -615,6 +741,17 @@ def test_patch_entity_metadata_entries(client):
|
|||||||
# Check the response data
|
# Check the response data
|
||||||
patched_entity_data = patch_response.json()
|
patched_entity_data = patch_response.json()
|
||||||
expected_data = load_fixture("patch_entity_metadata_response.json")
|
expected_data = load_fixture("patch_entity_metadata_response.json")
|
||||||
|
|
||||||
|
# 检查并移除 last_scan_at
|
||||||
|
assert "last_scan_at" in patched_entity_data
|
||||||
|
assert isinstance(patched_entity_data["last_scan_at"], str)
|
||||||
|
|
||||||
|
datetime.fromisoformat(patched_entity_data["last_scan_at"].replace("Z", "+00:00"))
|
||||||
|
|
||||||
|
del patched_entity_data["last_scan_at"]
|
||||||
|
if "last_scan_at" in expected_data:
|
||||||
|
del expected_data["last_scan_at"]
|
||||||
|
|
||||||
assert patched_entity_data == expected_data
|
assert patched_entity_data == expected_data
|
||||||
|
|
||||||
# Update the "author" attribute of the entity
|
# Update the "author" attribute of the entity
|
||||||
|
Loading…
x
Reference in New Issue
Block a user