feat(index): support recreating fts and vec tables when reindexing

arkohut 2024-10-14 15:28:25 +08:00
parent de7cbc878a
commit 8c27354d86
5 changed files with 142 additions and 37 deletions
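Once installed, the force path is triggered from the CLI. Assuming the package's console entry point is memos (an assumption, since the entry-point name is not shown in this diff), the invocation would be "memos reindex --force", which drops and recreates the entities_fts and entities_vec virtual tables and then reindexes the default library.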


@@ -22,6 +22,7 @@ from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from concurrent.futures import ThreadPoolExecutor
from memos.models import recreate_fts_and_vec_tables
from memos.utils import get_image_metadata
from memos.schemas import MetadataSource
from memos.logging_config import LOGGING_CONFIG
@@ -536,6 +537,7 @@ async def update_entity(
def reindex(
library_id: int,
folders: List[int] = typer.Option(None, "--folder", "-f"),
force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing"),
):
print(f"Reindexing library {library_id}")
@@ -556,9 +558,27 @@ def reindex(
else:
library_folders = library["folders"]
def process_folders():
with httpx.Client(timeout=60) as client:
# Iterate through folders
if force:
print("Force flag is set. Recreating FTS and vector tables...")
recreate_fts_and_vec_tables()
print("FTS and vector tables have been recreated.")
with httpx.Client(timeout=60) as client:
total_entities = 0
# Get total entity count for all folders
for folder in library_folders:
response = client.get(
f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
params={"limit": 1, "offset": 0},
)
if response.status_code == 200:
total_entities += int(response.headers.get("X-Total-Count", 0))
else:
print(f"Failed to get entity count for folder {folder['id']}: {response.status_code} - {response.text}")
# Now process entities with a progress bar
with tqdm(total=total_entities, desc="Reindexing entities") as pbar:
for folder in library_folders:
print(f"Processing folder: {folder['id']}")
@@ -587,15 +607,14 @@ def reindex(
if update_response.status_code != 204:
print(f"Failed to update last_scan_at for entity {entity['id']}: {update_response.status_code} - {update_response.text}")
else:
print(f"Updated last_scan_at for entity {entity['id']}")
scanned_entities.add(entity["id"])
pbar.update(1)
offset += limit
process_folders()
print(f"Reindexing completed for library {library_id}")
async def check_and_index_entity(client, entity_id, entity_last_scan_at):
try:


@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
import httpx
import typer
from .config import settings
from .config import settings, display_config
from .models import init_database
from .initialize_typesense import init_typesense
from .record import (
@@ -168,7 +168,9 @@ def typsense_index_default_library(
@app.command("reindex")
def reindex_default_library():
def reindex_default_library(
force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing")
):
"""
Reindex the default library for memos.
"""
@@ -189,7 +191,7 @@ def reindex_default_library():
# Reindex the library
print(f"Reindexing library: {default_library['name']}")
reindex(default_library["id"])
reindex(default_library["id"], force=force)
@app.command("record")
@@ -485,14 +487,14 @@ def ps():
create_time = datetime.fromtimestamp(process.info['create_time']).strftime('%Y-%m-%d %H:%M:%S')
running_time = str(timedelta(seconds=int(time.time() - process.info['create_time'])))
table_data.append([
service.capitalize(),
service,
"Running",
process.info['pid'],
create_time,
running_time
])
else:
table_data.append([service.capitalize(), "Not Running", "-", "-", "-"])
table_data.append([service, "Not Running", "-", "-", "-"])
headers = ["Name", "Status", "PID", "Started At", "Running For"]
typer.echo(tabulate(table_data, headers=headers, tablefmt="plain"))
@@ -559,5 +561,11 @@ def start():
typer.echo("Unsupported operating system.")
@app.command()
def config():
"""Show current configuration settings"""
display_config()
if __name__ == "__main__":
app()
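The new config subcommand prints the resolved settings via display_config() (added to the config module below); nested settings models are shown as indented blocks and SecretStr values are masked as ********. Assuming the same memos entry point as above, it is invoked as "memos config".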


@@ -11,6 +11,7 @@ from pydantic import BaseModel, SecretStr
import yaml
from collections import OrderedDict
import io
import typer
class VLMSettings(BaseModel):
@@ -22,7 +23,7 @@ class VLMSettings(BaseModel):
# some vlm models do not support webp
force_jpeg: bool = True
# prompt for vlm to extract caption
prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等"
class OCRSettings(BaseModel):
@@ -118,11 +119,13 @@ yaml.add_representer(OrderedDict, dict_representer)
def secret_str_representer(dumper, data):
return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value())
# Custom constructor for SecretStr
def secret_str_constructor(loader, node):
value = loader.construct_scalar(node)
return SecretStr(value)
# Register the representer and constructor only for specific fields
yaml.add_representer(SecretStr, secret_str_representer)
@@ -132,22 +135,25 @@ def create_default_config():
if not config_path.exists():
settings = Settings()
os.makedirs(config_path.parent, exist_ok=True)
# Convert the settings to a dict and preserve the field order
settings_dict = settings.model_dump()
ordered_settings = OrderedDict(
(key, settings_dict[key]) for key in settings.model_fields.keys()
)
# Use io.StringIO as an intermediate step
with io.StringIO() as string_buffer:
yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper)
yaml.dump(
ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper
)
yaml_content = string_buffer.getvalue()
# Write the content to the file, making sure to use UTF-8 encoding
with open(config_path, "w", encoding="utf-8") as f:
f.write(yaml_content)
# Create default config if it doesn't exist
create_default_config()
@@ -164,3 +170,36 @@ TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name
# Function to get the database path from environment variable or default
def get_database_path():
return settings.database_path
def format_value(value):
if isinstance(
value, (VLMSettings, OCRSettings, EmbeddingSettings, TypesenseSettings)
):
return (
"{\n"
+ "\n".join(f" {k}: {v}" for k, v in value.model_dump().items())
+ "\n }"
)
elif isinstance(value, (list, tuple)):
return f"[{', '.join(map(str, value))}]"
elif isinstance(value, SecretStr):
return "********" # Hide the actual value of SecretStr
else:
return str(value)
def display_config():
settings = Settings()
config_dict = settings.model_dump()
max_key_length = max(len(key) for key in config_dict.keys())
typer.echo("Current configuration settings:")
for key, value in config_dict.items():
formatted_value = format_value(value)
if "\n" in formatted_value:
typer.echo(f"{key}:")
for line in formatted_value.split("\n"):
typer.echo(f" {line}")
else:
typer.echo(f"{key.ljust(max_key_length)} : {formatted_value}")


@@ -33,7 +33,7 @@ def init_embedding_model():
model_dir = settings.embedding.model
logger.info(f"Using model: {model_dir}")
model = SentenceTransformer(model_dir, trust_remote_code=True, truncate_dim=768)
model = SentenceTransformer(model_dir, trust_remote_code=True)
model.to(device)
logger.info(f"Embedding model initialized on device: {device}")


@@ -9,7 +9,6 @@ from sqlalchemy import (
func,
Index,
event,
DDL,
)
from datetime import datetime
from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column, Session
@@ -18,10 +17,8 @@ from .schemas import MetadataSource, MetadataType, FolderType
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from .config import get_database_path, settings
import numpy as np
from sqlalchemy import text
import sqlite_vec
import os
import sys
from pathlib import Path
import json
@@ -202,6 +199,50 @@ def load_extension(dbapi_conn, connection_record):
dbapi_conn.execute("PRAGMA journal_mode=WAL")
def recreate_fts_and_vec_tables():
"""Recreate the entities_fts and entities_vec tables without repopulating data."""
db_path = get_database_path()
engine = create_engine(f"sqlite:///{db_path}")
event.listen(engine, "connect", load_extension)
Session = sessionmaker(bind=engine)
with Session() as session:
try:
# Drop existing tables
session.execute(text("DROP TABLE IF EXISTS entities_fts"))
session.execute(text("DROP TABLE IF EXISTS entities_vec"))
# Recreate entities_fts table
session.execute(
text(
"""
CREATE VIRTUAL TABLE entities_fts USING fts5(
id, filepath, tags, metadata,
tokenize = 'simple 0'
)
"""
)
)
# Recreate entities_vec table
session.execute(
text(
f"""
CREATE VIRTUAL TABLE entities_vec USING vec0(
embedding float[{settings.embedding.num_dim}]
)
"""
)
)
session.commit()
print("Successfully recreated entities_fts and entities_vec tables.")
except Exception as e:
session.rollback()
print(f"Error recreating tables: {e}")
def init_database():
"""Initialize the database."""
db_path = get_database_path()
@@ -214,25 +255,25 @@ def init_database():
Base.metadata.create_all(engine)
print(f"Database initialized successfully at {db_path}")
# Create FTS and Vec tables
with engine.connect() as conn:
conn.execute(
DDL(
text(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
id, filepath, tags, metadata,
tokenize = 'simple 0'
)
"""
)
)
with engine.connect() as conn:
conn.execute(
DDL(
text(
f"""
CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
embedding float[{settings.embedding.num_dim}]
)
"""
)
)
@@ -317,9 +358,7 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
try:
# First, try to update the existing row
result = session.execute(
text(
"UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"
),
text("UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"),
{
"id": target_id,
"embedding": serialize_float32(embedding),
@@ -337,7 +376,7 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
"embedding": serialize_float32(embedding),
},
)
session.commit()
except Exception as e:
print(f"Error updating entities_vec: {e}")
@@ -379,7 +418,7 @@ def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
"metadata": metadata,
},
)
session.commit()
except Exception as e:
print(f"Error updating entities_fts: {e}")