diff --git a/memos/cmds/library.py b/memos/cmds/library.py
index d69cc30..628ae96 100644
--- a/memos/cmds/library.py
+++ b/memos/cmds/library.py
@@ -22,6 +22,7 @@ from watchdog.observers import Observer
 from watchdog.events import FileSystemEventHandler
 from concurrent.futures import ThreadPoolExecutor
 
+from memos.models import recreate_fts_and_vec_tables
 from memos.utils import get_image_metadata
 from memos.schemas import MetadataSource
 from memos.logging_config import LOGGING_CONFIG
@@ -536,6 +537,7 @@ async def update_entity(
 def reindex(
     library_id: int,
     folders: List[int] = typer.Option(None, "--folder", "-f"),
+    force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing"),
 ):
     print(f"Reindexing library {library_id}")
 
@@ -556,9 +558,27 @@ def reindex(
     else:
         library_folders = library["folders"]
 
-    def process_folders():
-        with httpx.Client(timeout=60) as client:
-            # Iterate through folders
+    if force:
+        print("Force flag is set. Recreating FTS and vector tables...")
+        recreate_fts_and_vec_tables()
+        print("FTS and vector tables have been recreated.")
+
+    with httpx.Client(timeout=60) as client:
+        total_entities = 0
+
+        # Get total entity count for all folders
+        for folder in library_folders:
+            response = client.get(
+                f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
+                params={"limit": 1, "offset": 0},
+            )
+            if response.status_code == 200:
+                total_entities += int(response.headers.get("X-Total-Count", 0))
+            else:
+                print(f"Failed to get entity count for folder {folder['id']}: {response.status_code} - {response.text}")
+
+        # Now process entities with a progress bar
+        with tqdm(total=total_entities, desc="Reindexing entities") as pbar:
             for folder in library_folders:
                 print(f"Processing folder: {folder['id']}")
 
@@ -587,15 +607,14 @@ def reindex(
                     if update_response.status_code != 204:
                         print(f"Failed to update last_scan_at for entity {entity['id']}: {update_response.status_code} - {update_response.text}")
                     else:
-                        print(f"Updated last_scan_at for entity {entity['id']}")
-
-                    scanned_entities.add(entity["id"])
+                        scanned_entities.add(entity["id"])
+
+                    pbar.update(1)
 
                 offset += limit
 
-    process_folders()
     print(f"Reindexing completed for library {library_id}")
-    
+
 
 async def check_and_index_entity(client, entity_id, entity_last_scan_at):
     try:
diff --git a/memos/commands.py b/memos/commands.py
index 39c4993..a37b2a0 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
 
 import httpx
 import typer
-from .config import settings
+from .config import settings, display_config
 from .models import init_database
 from .initialize_typesense import init_typesense
 from .record import (
@@ -168,7 +168,9 @@ def typsense_index_default_library(
 
 
 @app.command("reindex")
-def reindex_default_library():
+def reindex_default_library(
+    force: bool = typer.Option(False, "--force", help="Force recreate FTS and vector tables before reindexing")
+):
     """
     Reindex the default library for memos.
""" @@ -189,7 +191,7 @@ def reindex_default_library(): # Reindex the library print(f"Reindexing library: {default_library['name']}") - reindex(default_library["id"]) + reindex(default_library["id"], force=force) @app.command("record") @@ -485,14 +487,14 @@ def ps(): create_time = datetime.fromtimestamp(process.info['create_time']).strftime('%Y-%m-%d %H:%M:%S') running_time = str(timedelta(seconds=int(time.time() - process.info['create_time']))) table_data.append([ - service.capitalize(), + service, "Running", process.info['pid'], create_time, running_time ]) else: - table_data.append([service.capitalize(), "Not Running", "-", "-", "-"]) + table_data.append([service, "Not Running", "-", "-", "-"]) headers = ["Name", "Status", "PID", "Started At", "Running For"] typer.echo(tabulate(table_data, headers=headers, tablefmt="plain")) @@ -559,5 +561,11 @@ def start(): typer.echo("Unsupported operating system.") +@app.command() +def config(): + """Show current configuration settings""" + display_config() + + if __name__ == "__main__": app() \ No newline at end of file diff --git a/memos/config.py b/memos/config.py index d6d1ef7..490e045 100644 --- a/memos/config.py +++ b/memos/config.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, SecretStr import yaml from collections import OrderedDict import io +import typer class VLMSettings(BaseModel): @@ -22,7 +23,7 @@ class VLMSettings(BaseModel): # some vlm models do not support webp force_jpeg: bool = True # prompt for vlm to extract caption - prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等" + prompt: str = "请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等" class OCRSettings(BaseModel): @@ -118,11 +119,13 @@ yaml.add_representer(OrderedDict, dict_representer) def secret_str_representer(dumper, data): return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value()) + # Custom constructor for SecretStr def secret_str_constructor(loader, node): value = loader.construct_scalar(node) return SecretStr(value) + # Register the representer and constructor only for specific fields yaml.add_representer(SecretStr, secret_str_representer) @@ -132,22 +135,25 @@ def create_default_config(): if not config_path.exists(): settings = Settings() os.makedirs(config_path.parent, exist_ok=True) - + # 将设置转换为字典并确保顺序 settings_dict = settings.model_dump() ordered_settings = OrderedDict( (key, settings_dict[key]) for key in settings.model_fields.keys() ) - + # 使用 io.StringIO 作为中间步骤 with io.StringIO() as string_buffer: - yaml.dump(ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper) + yaml.dump( + ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper + ) yaml_content = string_buffer.getvalue() - + # 将内容写入文件,确保使用 UTF-8 编码 with open(config_path, "w", encoding="utf-8") as f: f.write(yaml_content) + # Create default config if it doesn't exist create_default_config() @@ -164,3 +170,36 @@ TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name # Function to get the database path from environment variable or default def get_database_path(): return settings.database_path + + +def format_value(value): + if isinstance( + value, (VLMSettings, OCRSettings, EmbeddingSettings, TypesenseSettings) + ): + return ( + "{\n" + + "\n".join(f" {k}: {v}" for k, v in value.model_dump().items()) + + "\n }" + ) + elif isinstance(value, (list, tuple)): + return f"[{', '.join(map(str, value))}]" + elif isinstance(value, SecretStr): + return "********" # Hide the actual value of SecretStr + else: + return str(value) + + +def display_config(): + settings = Settings() + 
+    config_dict = settings.model_dump()
+    max_key_length = max(len(key) for key in config_dict.keys())
+
+    typer.echo("Current configuration settings:")
+    for key, value in config_dict.items():
+        formatted_value = format_value(value)
+        if "\n" in formatted_value:
+            typer.echo(f"{key}:")
+            for line in formatted_value.split("\n"):
+                typer.echo(f"  {line}")
+        else:
+            typer.echo(f"{key.ljust(max_key_length)} : {formatted_value}")
diff --git a/memos/embedding.py b/memos/embedding.py
index e676dd3..ac348f9 100644
--- a/memos/embedding.py
+++ b/memos/embedding.py
@@ -33,7 +33,7 @@ def init_embedding_model():
     model_dir = settings.embedding.model
     logger.info(f"Using model: {model_dir}")
 
-    model = SentenceTransformer(model_dir, trust_remote_code=True, truncate_dim=768)
+    model = SentenceTransformer(model_dir, trust_remote_code=True)
     model.to(device)
 
     logger.info(f"Embedding model initialized on device: {device}")
diff --git a/memos/models.py b/memos/models.py
index 6543bca..df4ce79 100644
--- a/memos/models.py
+++ b/memos/models.py
@@ -9,7 +9,6 @@ from sqlalchemy import (
     func,
     Index,
     event,
-    DDL,
 )
 from datetime import datetime
 from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column, Session
@@ -18,10 +17,8 @@ from .schemas import MetadataSource, MetadataType, FolderType
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import sessionmaker
 from .config import get_database_path, settings
-import numpy as np
 from sqlalchemy import text
 import sqlite_vec
-import os
 import sys
 from pathlib import Path
 import json
@@ -202,6 +199,50 @@ def load_extension(dbapi_conn, connection_record):
     dbapi_conn.execute("PRAGMA journal_mode=WAL")
 
 
+def recreate_fts_and_vec_tables():
+    """Recreate the entities_fts and entities_vec tables without repopulating data."""
+    db_path = get_database_path()
+    engine = create_engine(f"sqlite:///{db_path}")
+    event.listen(engine, "connect", load_extension)
+
+    Session = sessionmaker(bind=engine)
+
+    with Session() as session:
+        try:
+            # Drop existing tables
+            session.execute(text("DROP TABLE IF EXISTS entities_fts"))
+            session.execute(text("DROP TABLE IF EXISTS entities_vec"))
+
+            # Recreate entities_fts table
+            session.execute(
+                text(
+                    """
+                CREATE VIRTUAL TABLE entities_fts USING fts5(
+                    id, filepath, tags, metadata,
+                    tokenize = 'simple 0'
+                )
+                """
+                )
+            )
+
+            # Recreate entities_vec table
+            session.execute(
+                text(
+                    f"""
+                CREATE VIRTUAL TABLE entities_vec USING vec0(
+                    embedding float[{settings.embedding.num_dim}]
+                )
+                """
+                )
+            )
+
+            session.commit()
+            print("Successfully recreated entities_fts and entities_vec tables.")
+        except Exception as e:
+            session.rollback()
+            print(f"Error recreating tables: {e}")
+
+
 def init_database():
     """Initialize the database."""
     db_path = get_database_path()
@@ -214,25 +255,25 @@ def init_database():
     Base.metadata.create_all(engine)
     print(f"Database initialized successfully at {db_path}")
 
+    # Create FTS and Vec tables
     with engine.connect() as conn:
         conn.execute(
-            DDL(
+            text(
                 """
-            CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
-                id, filepath, tags, metadata,
-                tokenize = 'simple 0'
-            )
+                CREATE VIRTUAL TABLE IF NOT EXISTS entities_fts USING fts5(
+                    id, filepath, tags, metadata,
+                    tokenize = 'simple 0'
+                )
             """
             )
         )
-
     with engine.connect() as conn:
         conn.execute(
-            DDL(
+            text(
                 f"""
-            CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
-                embedding float[{settings.embedding.num_dim}]
-            )
+                CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
+                    embedding float[{settings.embedding.num_dim}]
+                )
             """
             )
         )
@@ -317,9 +358,7 @@ async def update_or_insert_entities_vec(session, target_id, embedding):
     try:
         # First, try to update the existing row
         result = session.execute(
-            text(
-                "UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"
-            ),
+            text("UPDATE entities_vec SET embedding = :embedding WHERE rowid = :id"),
             {
                 "id": target_id,
                 "embedding": serialize_float32(embedding),
             },
@@ -337,7 +376,7 @@
                 "embedding": serialize_float32(embedding),
             },
         )
-        
+
         session.commit()
     except Exception as e:
         print(f"Error updating entities_vec: {e}")
@@ -379,7 +418,7 @@ def update_or_insert_entities_fts(session, target_id, filepath, tags, metadata):
                 "metadata": metadata,
             },
         )
-        
+
         session.commit()
     except Exception as e:
         print(f"Error updating entities_fts: {e}")
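
Review note: the sketch below is one way to exercise the new recreate_fts_and_vec_tables() helper in isolation and confirm that both virtual tables exist again afterwards. It reuses load_extension, get_database_path, and the SQLAlchemy pattern exactly as the patch does; the script itself (its name and the sqlite_master query) is an assumption for local verification, not part of the patch.

# verify_recreate.py -- illustrative check, not part of this patch
from sqlalchemy import create_engine, event, text

from memos.config import get_database_path
from memos.models import load_extension, recreate_fts_and_vec_tables

# Drop and recreate the FTS5 and vec0 virtual tables (data is not repopulated here).
recreate_fts_and_vec_tables()

# Open the same SQLite database with the same extension loader used in models.py.
engine = create_engine(f"sqlite:///{get_database_path()}")
event.listen(engine, "connect", load_extension)

with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT name FROM sqlite_master "
            "WHERE name IN ('entities_fts', 'entities_vec') ORDER BY name"
        )
    ).fetchall()

# Expect both table names back; a full reindex is still required to repopulate them.
print([name for (name,) in rows])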
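
Review note: the counting pass plus tqdm progress bar added to reindex is the pattern most worth keeping in mind when reviewing this change. Below is a minimal, self-contained sketch of that pattern, assuming the same /libraries/{id}/folders/{id}/entities route and X-Total-Count header used in the diff; BASE_URL, the page size, and the function names here are illustrative only. With the patch applied, the end-to-end path is `memos reindex --force` (which now recreates entities_fts and entities_vec before reindexing), and the new `memos config` command prints the resolved settings via display_config().

# progress_pattern.py -- illustrative sketch of the count-then-iterate pattern above
import httpx
from tqdm import tqdm

BASE_URL = "http://localhost:8839"  # assumption: replace with the local memos server URL


def count_entities(client: httpx.Client, library_id: int, folder_ids: list) -> int:
    """Probe each folder with limit=1 and sum the X-Total-Count headers."""
    total = 0
    for folder_id in folder_ids:
        response = client.get(
            f"{BASE_URL}/libraries/{library_id}/folders/{folder_id}/entities",
            params={"limit": 1, "offset": 0},
        )
        if response.status_code == 200:
            total += int(response.headers.get("X-Total-Count", 0))
    return total


def walk_entities(library_id: int, folder_ids: list, limit: int = 200):
    """Yield entities folder by folder while advancing a single progress bar."""
    with httpx.Client(timeout=60) as client:
        total = count_entities(client, library_id, folder_ids)
        with tqdm(total=total, desc="Reindexing entities") as pbar:
            for folder_id in folder_ids:
                offset = 0
                while True:
                    response = client.get(
                        f"{BASE_URL}/libraries/{library_id}/folders/{folder_id}/entities",
                        params={"limit": limit, "offset": offset},
                    )
                    entities = response.json()
                    if not entities:
                        break
                    for entity in entities:
                        pbar.update(1)
                        yield entity
                    offset += limit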