Mirror of https://github.com/tcsenpai/pensieve.git (synced 2025-06-06 19:25:24 +00:00)

feat(index): support recreate fts and vec tables when reindexing

parent: de7cbc878a
commit: 8c27354d86
@@ -22,6 +22,7 @@ from watchdog.observers import Observer
 from watchdog.events import FileSystemEventHandler
 from concurrent.futures import ThreadPoolExecutor
 
+from memos.models import recreate_fts_and_vec_tables
 from memos.utils import get_image_metadata
 from memos.schemas import MetadataSource
 from memos.logging_config import LOGGING_CONFIG
@@ -536,6 +537,7 @@ async def update_entity(
 def reindex(
     library_id: int,
     folders: List[int] = typer.Option(None, "--folder", "-f"),
+    force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing"),
 ):
     print(f"Reindexing library {library_id}")
 
@@ -556,9 +558,27 @@ def reindex(
     else:
         library_folders = library["folders"]
 
-    def process_folders():
-        with httpx.Client(timeout=60) as client:
-            # Iterate through folders
+    if force:
+        print("Force flag is set. Recreating FTS and vector tables...")
+        recreate_fts_and_vec_tables()
+        print("FTS and vector tables have been recreated.")
+
+    with httpx.Client(timeout=60) as client:
+        total_entities = 0
+
+        # Get total entity count for all folders
+        for folder in library_folders:
+            response = client.get(
+                f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
+                params={"limit": 1, "offset": 0},
+            )
+            if response.status_code == 200:
+                total_entities += int(response.headers.get("X-Total-Count", 0))
+            else:
+                print(f"Failed to get entity count for folder {folder['id']}: {response.status_code} - {response.text}")
+
+        # Now process entities with a progress bar
+        with tqdm(total=total_entities, desc="Reindexing entities") as pbar:
             for folder in library_folders:
                 print(f"Processing folder: {folder['id']}")
 
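
Note: the pre-count above leans on the server's X-Total-Count response header so tqdm can show a meaningful total. A standalone sketch of the same pattern (the BASE_URL value here is an assumption; the real module defines its own):

    import httpx

    BASE_URL = "http://localhost:8839"  # assumed; taken from the surrounding module in the real code

    def count_entities(library_id: int, folder_id: int) -> int:
        # Ask for a single entity and read the total from the response header
        # instead of paging through every record.
        with httpx.Client(timeout=60) as client:
            resp = client.get(
                f"{BASE_URL}/libraries/{library_id}/folders/{folder_id}/entities",
                params={"limit": 1, "offset": 0},
            )
            resp.raise_for_status()
            return int(resp.headers.get("X-Total-Count", 0))
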
@@ -587,15 +607,14 @@ def reindex(
                     if update_response.status_code != 204:
                         print(f"Failed to update last_scan_at for entity {entity['id']}: {update_response.status_code} - {update_response.text}")
                     else:
-                        print(f"Updated last_scan_at for entity {entity['id']}")
+                        scanned_entities.add(entity["id"])
 
-                    scanned_entities.add(entity["id"])
+                    pbar.update(1)
 
                 offset += limit
 
-    process_folders()
     print(f"Reindexing completed for library {library_id}")
 
 
 async def check_and_index_entity(client, entity_id, entity_last_scan_at):
     try:
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
 
 import httpx
 import typer
-from .config import settings
+from .config import settings, display_config
 from .models import init_database
 from .initialize_typesense import init_typesense
 from .record import (
@@ -168,7 +168,9 @@ def typsense_index_default_library(
 
 
 @app.command("reindex")
-def reindex_default_library():
+def reindex_default_library(
+    force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing")
+):
     """
     Reindex the default library for memos.
     """
@@ -189,7 +191,7 @@ def reindex_default_library():
 
     # Reindex the library
     print(f"Reindexing library: {default_library['name']}")
-    reindex(default_library["id"])
+    reindex(default_library["id"], force=force)
 
 
 @app.command("record")
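
With the plumbing above, --force is reachable from the CLI entry point. A hedged smoke test via typer's test runner (the module path memos.commands is an assumption):

    from typer.testing import CliRunner

    from memos.commands import app  # assumed location of the Typer app

    runner = CliRunner()
    # Shell equivalent: `memos reindex --force` (console-script name assumed).
    result = runner.invoke(app, ["reindex", "--force"])
    print(result.output)
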
@@ -485,14 +487,14 @@ def ps():
             create_time = datetime.fromtimestamp(process.info['create_time']).strftime('%Y-%m-%d %H:%M:%S')
             running_time = str(timedelta(seconds=int(time.time() - process.info['create_time'])))
             table_data.append([
-                service.capitalize(),
+                service,
                 "Running",
                 process.info['pid'],
                 create_time,
                 running_time
             ])
         else:
-            table_data.append([service.capitalize(), "Not Running", "-", "-", "-"])
+            table_data.append([service, "Not Running", "-", "-", "-"])
 
     headers = ["Name", "Status", "PID", "Started At", "Running For"]
     typer.echo(tabulate(table_data, headers=headers, tablefmt="plain"))
@@ -559,5 +561,11 @@ def start():
         typer.echo("Unsupported operating system.")
 
 
+@app.command()
+def config():
+    """Show current configuration settings"""
+    display_config()
+
+
 if __name__ == "__main__":
     app()
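
The new command is a thin wrapper over display_config() imported earlier in this file. A direct equivalent, assuming the package layout:

    from memos.config import display_config  # assumed module path

    # Same effect as running `memos config` from a shell.
    display_config()
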
@@ -11,6 +11,7 @@ from pydantic import BaseModel, SecretStr
 import yaml
 from collections import OrderedDict
 import io
+import typer
 
 
 class VLMSettings(BaseModel):
@@ -22,7 +23,7 @@ class VLMSettings(BaseModel):
     # some vlm models do not support webp
     force_jpeg: bool = True
     # prompt for vlm to extract caption
     prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等"
 
 
 class OCRSettings(BaseModel):
@@ -118,11 +119,13 @@ yaml.add_representer(OrderedDict, dict_representer)
 def secret_str_representer(dumper, data):
     return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value())
+
+
 # Custom constructor for SecretStr
 def secret_str_constructor(loader, node):
     value = loader.construct_scalar(node)
     return SecretStr(value)
 
 
 # Register the representer and constructor only for specific fields
 yaml.add_representer(SecretStr, secret_str_representer)
 
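
For context, these hooks let SecretStr round-trip through YAML as a plain string. A self-contained sketch of the representer side:

    import yaml
    from pydantic import SecretStr

    def secret_str_representer(dumper, data):
        # Same representer as above: dump the underlying secret value.
        return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value())

    yaml.add_representer(SecretStr, secret_str_representer)

    print(yaml.dump({"api_key": SecretStr("s3cr3t")}))  # -> api_key: s3cr3t
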
@@ -132,22 +135,25 @@ def create_default_config():
     if not config_path.exists():
         settings = Settings()
         os.makedirs(config_path.parent, exist_ok=True)
 
         # 将设置转换为字典并确保顺序
         settings_dict = settings.model_dump()
         ordered_settings = OrderedDict(
             (key, settings_dict[key]) for key in settings.model_fields.keys()
         )
 
         # 使用 io.StringIO 作为中间步骤
         with io.StringIO() as string_buffer:
-            yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper)
+            yaml.dump(
+                ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper
+            )
             yaml_content = string_buffer.getvalue()
 
             # 将内容写入文件,确保使用 UTF-8 编码
             with open(config_path, "w", encoding="utf-8") as f:
                 f.write(yaml_content)
 
 
 # Create default config if it doesn't exist
 create_default_config()
 
@@ -164,3 +170,36 @@ TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name
 # Function to get the database path from environment variable or default
 def get_database_path():
     return settings.database_path
+
+
+def format_value(value):
+    if isinstance(
+        value, (VLMSettings, OCRSettings, EmbeddingSettings, TypesenseSettings)
+    ):
+        return (
+            "{\n"
+            + "\n".join(f"    {k}: {v}" for k, v in value.model_dump().items())
+            + "\n  }"
+        )
+    elif isinstance(value, (list, tuple)):
+        return f"[{', '.join(map(str, value))}]"
+    elif isinstance(value, SecretStr):
+        return "********"  # Hide the actual value of SecretStr
+    else:
+        return str(value)
+
+
+def display_config():
+    settings = Settings()
+    config_dict = settings.model_dump()
+    max_key_length = max(len(key) for key in config_dict.keys())
+
+    typer.echo("Current configuration settings:")
+    for key, value in config_dict.items():
+        formatted_value = format_value(value)
+        if "\n" in formatted_value:
+            typer.echo(f"{key}:")
+            for line in formatted_value.split("\n"):
+                typer.echo(f"  {line}")
+        else:
+            typer.echo(f"{key.ljust(max_key_length)} : {formatted_value}")
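
For reference, a few concrete cases of format_value as defined above (assuming the surrounding definitions are in scope):

    from pydantic import SecretStr

    assert format_value([1, 2, 3]) == "[1, 2, 3]"      # list/tuple -> bracketed join
    assert format_value(SecretStr("x")) == "********"  # secrets stay hidden
    assert format_value(42) == "42"                    # everything else -> str()
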
@@ -33,7 +33,7 @@ def init_embedding_model():
     model_dir = settings.embedding.model
     logger.info(f"Using model: {model_dir}")
 
-    model = SentenceTransformer(model_dir, trust_remote_code=True, truncate_dim=768)
+    model = SentenceTransformer(model_dir, trust_remote_code=True)
     model.to(device)
     logger.info(f"Embedding model initialized on device: {device}")
 
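
Without truncate_dim, the model's native output width is used, so it has to agree with the float[...] size of the entities_vec table created from settings.embedding.num_dim. A hedged sanity check (the model name here is an assumption):

    from sentence_transformers import SentenceTransformer

    # Assumed default model; substitute settings.embedding.model in practice.
    model = SentenceTransformer("jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True)
    # Must equal settings.embedding.num_dim, or vec0 inserts will fail.
    print(model.get_sentence_embedding_dimension())
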
@@ -9,7 +9,6 @@ from sqlalchemy import (
     func,
     Index,
     event,
-    DDL,
 )
 from datetime import datetime
 from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column, Session
@@ -18,10 +17,8 @@ from .schemas import MetadataSource, MetadataType, FolderType
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import sessionmaker
 from .config import get_database_path, settings
-import numpy as np
 from sqlalchemy import text
 import sqlite_vec
-import os
 import sys
 from pathlib import Path
 import json
@@ -202,6 +199,50 @@ def load_extension(dbapi_conn, connection_record):
     dbapi_conn.execute("PRAGMA journal_mode=WAL")
 
 
+def recreate_fts_and_vec_tables():
+    """Recreate the entities_fts and entities_vec tables without repopulating data."""
+    db_path = get_database_path()
+    engine = create_engine(f"sqlite:///{db_path}")
+    event.listen(engine, "connect", load_extension)
+
+    Session = sessionmaker(bind=engine)
+
+    with Session() as session:
+        try:
+            # Drop existing tables
+            session.execute(text("DROP TABLE IF EXISTS entities_fts"))
+            session.execute(text("DROP TABLE IF EXISTS entities_vec"))
+
+            # Recreate entities_fts table
+            session.execute(
+                text(
+                    """
+                    CREATE VIRTUAL TABLE entities_fts USING fts5(
+                        id, filepath, tags, metadata,
+                        tokenize = 'simple 0'
+                    )
+                    """
+                )
+            )
+
+            # Recreate entities_vec table
+            session.execute(
+                text(
+                    f"""
+                    CREATE VIRTUAL TABLE entities_vec USING vec0(
+                        embedding float[{settings.embedding.num_dim}]
+                    )
+                    """
+                )
+            )
+
+            session.commit()
+            print("Successfully recreated entities_fts and entities_vec tables.")
+        except Exception as e:
+            session.rollback()
+            print(f"Error recreating tables: {e}")
+
+
 def init_database():
     """Initialize the database."""
     db_path = get_database_path()
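
Worth noting: recreate_fts_and_vec_tables is destructive and does not repopulate either table, which is why the CLI pairs it with a full reindex. Calling it directly (import path taken from the first hunk of this commit):

    from memos.models import recreate_fts_and_vec_tables

    # Drops and recreates entities_fts / entities_vec; follow with a reindex,
    # or use `reindex --force`, which does both in one step.
    recreate_fts_and_vec_tables()
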
@@ -214,25 +255,25 @@ def init_database():
     Base.metadata.create_all(engine)
     print(f"Database initialized successfully at {db_path}")
 
+    # Create FTS and Vec tables
     with engine.connect() as conn:
         conn.execute(
-            DDL(
+            text(
                 """
                 CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
                     id, filepath, tags, metadata,
                     tokenize = 'simple 0'
                 )
                 """
             )
         )
 
-    with engine.connect() as conn:
         conn.execute(
-            DDL(
+            text(
                 f"""
                 CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
                     embedding float[{settings.embedding.num_dim}]
                 )
                 """
             )
         )
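
Swapping DDL(...) for text(...) does not change behavior here; both execute the raw statement. It matches the construct used in recreate_fts_and_vec_tables above and lets the now-unused DDL import go. A minimal sketch, assuming a stock SQLite build (the 'simple' tokenizer from the real schema is a loadable extension and is omitted):

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.connect() as conn:
        # fts5 ships with standard SQLite builds.
        conn.execute(text("CREATE VIRTUAL TABLE IF NOT EXISTS t_fts USING fts5(content)"))
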
@@ -317,9 +358,7 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
     try:
         # First, try to update the existing row
         result = session.execute(
-            text(
-                "UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"
-            ),
+            text("UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"),
             {
                 "id": target_id,
                 "embedding": serialize_float32(embedding),
@@ -337,7 +376,7 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
                 "embedding": serialize_float32(embedding),
             },
         )
 
         session.commit()
     except Exception as e:
         print(f"Error updating entities_vec: {e}")
@@ -379,7 +418,7 @@ def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
                 "metadata": metadata,
             },
         )
 
         session.commit()
     except Exception as e:
         print(f"Error updating entities_fts: {e}")