mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(index): support recreate fts and vec tables when reindexing
This commit is contained in:
parent
de7cbc878a
commit
8c27354d86
@ -22,6 +22,7 @@ from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from memos.models import recreate_fts_and_vec_tables
|
||||
from memos.utils import get_image_metadata
|
||||
from memos.schemas import MetadataSource
|
||||
from memos.logging_config import LOGGING_CONFIG
|
||||
@ -536,6 +537,7 @@ async def update_entity(
|
||||
def reindex(
|
||||
library_id: int,
|
||||
folders: List[int] = typer.Option(None, "--folder", "-f"),
|
||||
force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing"),
|
||||
):
|
||||
print(f"Reindexing library {library_id}")
|
||||
|
||||
@ -556,9 +558,27 @@ def reindex(
|
||||
else:
|
||||
library_folders = library["folders"]
|
||||
|
||||
def process_folders():
|
||||
with httpx.Client(timeout=60) as client:
|
||||
# Iterate through folders
|
||||
if force:
|
||||
print("Force flag is set. Recreating FTS and vector tables...")
|
||||
recreate_fts_and_vec_tables()
|
||||
print("FTS and vector tables have been recreated.")
|
||||
|
||||
with httpx.Client(timeout=60) as client:
|
||||
total_entities = 0
|
||||
|
||||
# Get total entity count for all folders
|
||||
for folder in library_folders:
|
||||
response = client.get(
|
||||
f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
|
||||
params={"limit": 1, "offset": 0},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
total_entities += int(response.headers.get("X-Total-Count", 0))
|
||||
else:
|
||||
print(f"Failed to get entity count for folder {folder['id']}: {response.status_code} - {response.text}")
|
||||
|
||||
# Now process entities with a progress bar
|
||||
with tqdm(total=total_entities, desc="Reindexing entities") as pbar:
|
||||
for folder in library_folders:
|
||||
print(f"Processing folder: {folder['id']}")
|
||||
|
||||
@ -587,15 +607,14 @@ def reindex(
|
||||
if update_response.status_code != 204:
|
||||
print(f"Failed to update last_scan_at for entity {entity['id']}: {update_response.status_code} - {update_response.text}")
|
||||
else:
|
||||
print(f"Updated last_scan_at for entity {entity['id']}")
|
||||
|
||||
scanned_entities.add(entity["id"])
|
||||
scanned_entities.add(entity["id"])
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
offset += limit
|
||||
|
||||
process_folders()
|
||||
print(f"Reindexing completed for library {library_id}")
|
||||
|
||||
|
||||
|
||||
async def check_and_index_entity(client, entity_id, entity_last_scan_at):
|
||||
try:
|
||||
|
@ -5,7 +5,7 @@ from datetime import datetime, timedelta
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
from .config import settings
|
||||
from .config import settings, display_config
|
||||
from .models import init_database
|
||||
from .initialize_typesense import init_typesense
|
||||
from .record import (
|
||||
@ -168,7 +168,9 @@ def typsense_index_default_library(
|
||||
|
||||
|
||||
@app.command("reindex")
|
||||
def reindex_default_library():
|
||||
def reindex_default_library(
|
||||
force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing")
|
||||
):
|
||||
"""
|
||||
Reindex the default library for memos.
|
||||
"""
|
||||
@ -189,7 +191,7 @@ def reindex_default_library():
|
||||
|
||||
# Reindex the library
|
||||
print(f"Reindexing library: {default_library['name']}")
|
||||
reindex(default_library["id"])
|
||||
reindex(default_library["id"], force=force)
|
||||
|
||||
|
||||
@app.command("record")
|
||||
@ -485,14 +487,14 @@ def ps():
|
||||
create_time = datetime.fromtimestamp(process.info['create_time']).strftime('%Y-%m-%d %H:%M:%S')
|
||||
running_time = str(timedelta(seconds=int(time.time() - process.info['create_time'])))
|
||||
table_data.append([
|
||||
service.capitalize(),
|
||||
service,
|
||||
"Running",
|
||||
process.info['pid'],
|
||||
create_time,
|
||||
running_time
|
||||
])
|
||||
else:
|
||||
table_data.append([service.capitalize(), "Not Running", "-", "-", "-"])
|
||||
table_data.append([service, "Not Running", "-", "-", "-"])
|
||||
|
||||
headers = ["Name", "Status", "PID", "Started At", "Running For"]
|
||||
typer.echo(tabulate(table_data, headers=headers, tablefmt="plain"))
|
||||
@ -559,5 +561,11 @@ def start():
|
||||
typer.echo("Unsupported operating system.")
|
||||
|
||||
|
||||
@app.command()
def config():
    """Show current configuration settings"""
    # All rendering is delegated to display_config(), which this module
    # imports from .config; this command adds no formatting of its own.
    display_config()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
@ -11,6 +11,7 @@ from pydantic import BaseModel, SecretStr
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
import io
|
||||
import typer
|
||||
|
||||
|
||||
class VLMSettings(BaseModel):
|
||||
@ -22,7 +23,7 @@ class VLMSettings(BaseModel):
|
||||
# some vlm models do not support webp
|
||||
force_jpeg: bool = True
|
||||
# prompt for vlm to extract caption
|
||||
prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等"
|
||||
prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等"
|
||||
|
||||
|
||||
class OCRSettings(BaseModel):
|
||||
@ -118,11 +119,13 @@ yaml.add_representer(OrderedDict, dict_representer)
|
||||
def secret_str_representer(dumper, data):
    """Represent a secret-string value as a plain YAML string scalar.

    Args:
        dumper: Object exposing ``represent_scalar(tag, value)``.
        data: Object exposing ``get_secret_value()``.

    Returns:
        Whatever ``dumper.represent_scalar`` returns for the unwrapped value.

    Note:
        The secret is written out unmasked, since the value passed to the
        dumper is the result of ``get_secret_value()``.
    """
    emit_scalar = dumper.represent_scalar
    return emit_scalar("tag:yaml.org,2002:str", data.get_secret_value())
|
||||
|
||||
|
||||
# Custom constructor for SecretStr
|
||||
def secret_str_constructor(loader, node):
    """Build a ``SecretStr`` from a YAML scalar node.

    Args:
        loader: Object exposing ``construct_scalar(node)``.
        node: The node handed to ``loader.construct_scalar``.

    Returns:
        A ``SecretStr`` wrapping the scalar value read from ``node``.
    """
    return SecretStr(loader.construct_scalar(node))
|
||||
|
||||
|
||||
# Register the representer and constructor only for specific fields
|
||||
yaml.add_representer(SecretStr, secret_str_representer)
|
||||
|
||||
@ -132,22 +135,25 @@ def create_default_config():
|
||||
if not config_path.exists():
|
||||
settings = Settings()
|
||||
os.makedirs(config_path.parent, exist_ok=True)
|
||||
|
||||
|
||||
# 将设置转换为字典并确保顺序
|
||||
settings_dict = settings.model_dump()
|
||||
ordered_settings = OrderedDict(
|
||||
(key, settings_dict[key]) for key in settings.model_fields.keys()
|
||||
)
|
||||
|
||||
|
||||
# 使用 io.StringIO 作为中间步骤
|
||||
with io.StringIO() as string_buffer:
|
||||
yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper)
|
||||
yaml.dump(
|
||||
ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper
|
||||
)
|
||||
yaml_content = string_buffer.getvalue()
|
||||
|
||||
|
||||
# 将内容写入文件,确保使用 UTF-8 编码
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(yaml_content)
|
||||
|
||||
|
||||
# Create default config if it doesn't exist
|
||||
create_default_config()
|
||||
|
||||
@ -164,3 +170,36 @@ TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name
|
||||
# Function to get the database path from environment variable or default
|
||||
def get_database_path():
    """Return the ``database_path`` value of the module-level ``settings``."""
    database_path = settings.database_path
    return database_path
|
||||
|
||||
|
||||
def format_value(value):
    """Render a single configuration value as display text.

    Args:
        value: The value to render.

    Returns:
        A string:
        * for a ``VLMSettings``, ``OCRSettings``, ``EmbeddingSettings`` or
          ``TypesenseSettings`` instance, a brace-wrapped block with one
          ``key: value`` line per entry of ``value.model_dump()``;
        * for a list or tuple, the items joined with ``", "`` inside brackets;
        * for a ``SecretStr``, a fixed mask so the secret is never shown;
        * otherwise ``str(value)``.
    """
    nested_settings_types = (
        VLMSettings,
        OCRSettings,
        EmbeddingSettings,
        TypesenseSettings,
    )
    if isinstance(value, nested_settings_types):
        field_lines = "\n".join(
            f" {k}: {v}" for k, v in value.model_dump().items()
        )
        return "{\n" + field_lines + "\n }"

    if isinstance(value, (list, tuple)):
        joined_items = ", ".join(str(item) for item in value)
        return f"[{joined_items}]"

    if isinstance(value, SecretStr):
        # Hide the actual value of SecretStr
        return "********"

    return str(value)
|
||||
|
||||
|
||||
def display_config():
    """Print the current configuration settings with ``typer.echo``.

    A fresh ``Settings()`` instance is created and each of its top-level
    entries is rendered through ``format_value``. Values that render on a
    single line are printed as ``key : value`` with keys padded to a common
    width; values that render on several lines are printed under a ``key:``
    heading, one echoed line per rendered line.
    """
    settings = Settings()
    config_dict = settings.model_dump()
    # default=0 keeps max() from raising if the dump is ever empty.
    max_key_length = max((len(key) for key in config_dict), default=0)

    typer.echo("Current configuration settings:")
    for key, dumped_value in config_dict.items():
        # model_dump() converts nested settings models into plain dicts, so
        # passing the dumped value would bypass format_value's settings-model
        # branch and the multi-line output below would never be used. Read the
        # live attribute instead; fall back to the dumped value for any key
        # that is not an attribute of the settings object.
        value = getattr(settings, key, dumped_value)
        formatted_value = format_value(value)
        if "\n" in formatted_value:
            typer.echo(f"{key}:")
            for line in formatted_value.split("\n"):
                typer.echo(f" {line}")
        else:
            typer.echo(f"{key.ljust(max_key_length)} : {formatted_value}")
|
||||
|
@ -33,7 +33,7 @@ def init_embedding_model():
|
||||
model_dir = settings.embedding.model
|
||||
logger.info(f"Using model: {model_dir}")
|
||||
|
||||
model = SentenceTransformer(model_dir, trust_remote_code=True, truncate_dim=768)
|
||||
model = SentenceTransformer(model_dir, trust_remote_code=True)
|
||||
model.to(device)
|
||||
logger.info(f"Embedding model initialized on device: {device}")
|
||||
|
||||
|
@ -9,7 +9,6 @@ from sqlalchemy import (
|
||||
func,
|
||||
Index,
|
||||
event,
|
||||
DDL,
|
||||
)
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column, Session
|
||||
@ -18,10 +17,8 @@ from .schemas import MetadataSource, MetadataType, FolderType
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from .config import get_database_path, settings
|
||||
import numpy as np
|
||||
from sqlalchemy import text
|
||||
import sqlite_vec
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
@ -202,6 +199,50 @@ def load_extension(dbapi_conn, connection_record):
|
||||
dbapi_conn.execute("PRAGMA journal_mode=WAL")
|
||||
|
||||
|
||||
def recreate_fts_and_vec_tables():
    """Recreate the entities_fts and entities_vec tables without repopulating data.

    Opens a dedicated engine on the configured SQLite database, registers
    ``load_extension`` as a ``connect`` listener, then drops and recreates
    both virtual tables in a single session. The vector table's dimension is
    taken from ``settings.embedding.num_dim``.

    Any exception raised while executing the statements is caught: the session
    is rolled back and the error is printed. It is NOT re-raised, so callers
    cannot detect a failure from the return value (always ``None``).
    """
    db_path = get_database_path()
    engine = create_engine(f"sqlite:///{db_path}")
    event.listen(engine, "connect", load_extension)

    Session = sessionmaker(bind=engine)

    try:
        with Session() as session:
            try:
                # Drop existing tables
                session.execute(text("DROP TABLE IF EXISTS entities_fts"))
                session.execute(text("DROP TABLE IF EXISTS entities_vec"))

                # Recreate entities_fts table
                session.execute(
                    text(
                        """
                        CREATE VIRTUAL TABLE entities_fts USING fts5(
                            id, filepath, tags, metadata,
                            tokenize = 'simple 0'
                        )
                        """
                    )
                )

                # Recreate entities_vec table
                session.execute(
                    text(
                        f"""
                        CREATE VIRTUAL TABLE entities_vec USING vec0(
                            embedding float[{settings.embedding.num_dim}]
                        )
                        """
                    )
                )

                session.commit()
                print("Successfully recreated entities_fts and entities_vec tables.")
            except Exception as e:
                # Best-effort: report and continue rather than abort the caller.
                session.rollback()
                print(f"Error recreating tables: {e}")
    finally:
        # This engine exists only for this call; release its pooled
        # connections instead of leaving them open until garbage collection.
        engine.dispose()
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize the database."""
|
||||
db_path = get_database_path()
|
||||
@ -214,25 +255,25 @@ def init_database():
|
||||
Base.metadata.create_all(engine)
|
||||
print(f"Database initialized successfully at {db_path}")
|
||||
|
||||
# Create FTS and Vec tables
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
DDL(
|
||||
text(
|
||||
"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
|
||||
id, filepath, tags, metadata,
|
||||
tokenize = 'simple 0'
|
||||
)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
|
||||
id, filepath, tags, metadata,
|
||||
tokenize = 'simple 0'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
DDL(
|
||||
text(
|
||||
f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
|
||||
embedding float[{settings.embedding.num_dim}]
|
||||
)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
|
||||
embedding float[{settings.embedding.num_dim}]
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
@ -317,9 +358,7 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
|
||||
try:
|
||||
# First, try to update the existing row
|
||||
result = session.execute(
|
||||
text(
|
||||
"UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"
|
||||
),
|
||||
text("UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"),
|
||||
{
|
||||
"id": target_id,
|
||||
"embedding": serialize_float32(embedding),
|
||||
@ -337,7 +376,7 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
|
||||
"embedding": serialize_float32(embedding),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
print(f"Error updating entities_vec: {e}")
|
||||
@ -379,7 +418,7 @@ def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
|
||||
"metadata": metadata,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
print(f"Error updating entities_fts: {e}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user