From f5246227281d5a2681ad79137ef890ded04a644c Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Wed, 6 Nov 2024 16:41:10 +0800 Subject: [PATCH] refact: add batch size for scan cmd --- memos/cmds/library.py | 7 ++++--- memos/commands.py | 19 ++++++++++++++----- memos/config.py | 2 -- memos/default_config.yaml | 1 - 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/memos/cmds/library.py b/memos/cmds/library.py index e84bd03..aece10b 100644 --- a/memos/cmds/library.py +++ b/memos/cmds/library.py @@ -169,11 +169,11 @@ def is_temp_file(filename): ) -async def loop_files(library_id, folder, folder_path, force, plugins): +async def loop_files(library_id, folder, folder_path, force, plugins, batch_size): updated_file_count = 0 added_file_count = 0 scanned_files = set() - semaphore = asyncio.Semaphore(settings.batchsize) + semaphore = asyncio.Semaphore(batch_size) async with httpx.AsyncClient(timeout=60) as client: tasks = [] @@ -377,6 +377,7 @@ def scan( force: bool = False, plugins: List[int] = typer.Option(None, "--plugin", "-p"), folders: List[int] = typer.Option(None, "--folder", "-f"), + batch_size: int = typer.Option(1, "--batch-size", "-bs", help="Batch size for processing files"), ): # Check if both path and folders are provided if path and folders: @@ -426,7 +427,7 @@ def scan( continue added_file_count, updated_file_count, scanned_files = asyncio.run( - loop_files(library_id, folder, folder_path, force, plugins) + loop_files(library_id, folder, folder_path, force, plugins, batch_size) ) total_files_added += added_file_count total_files_updated += updated_file_count diff --git a/memos/commands.py b/memos/commands.py index 24aa574..d0320f8 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -92,7 +92,7 @@ def serve(): def init(): """Initialize the database.""" from .models import init_database - + db_success = init_database() if db_success: print("Initialization completed successfully.") @@ -106,6 +106,7 @@ def get_or_create_default_library(): Ensure the library has at least one folder. """ from .cmds.plugin import bind + response = httpx.get(f"{BASE_URL}/libraries") if response.status_code != 200: print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") @@ -162,19 +163,27 @@ def scan_default_library( path: str = typer.Argument(None, help="Path to scan within the library"), plugins: List[int] = typer.Option(None, "--plugin", "-p"), folders: List[int] = typer.Option(None, "--folder", "-f"), + batch_size: int = typer.Option( + 1, "--batch-size", "-bs", help="Batch size for processing files" + ), ): """ Scan the screenshots directory and add it to the library if empty. """ from .cmds.library import scan - + default_library = get_or_create_default_library() if not default_library: return print(f"Scanning library: {default_library['name']}") scan( - default_library["id"], path=path, plugins=plugins, folders=folders, force=force + default_library["id"], + path=path, + plugins=plugins, + folders=folders, + force=force, + batch_size=batch_size, ) @@ -188,7 +197,7 @@ def reindex_default_library( Reindex the default library for memos. """ from .cmds.library import reindex - + # Get the default library response = httpx.get(f"{BASE_URL}/libraries") if response.status_code != 200: @@ -258,7 +267,7 @@ def watch_default_library( Watch the default library for file changes and sync automatically. """ from .cmds.library import watch - + default_library = get_or_create_default_library() if not default_library: return diff --git a/memos/config.py b/memos/config.py index fc78fcf..368984c 100644 --- a/memos/config.py +++ b/memos/config.py @@ -71,8 +71,6 @@ class Settings(BaseSettings): # Embedding settings embedding: EmbeddingSettings = EmbeddingSettings() - batchsize: int = 1 - auth_username: str = "admin" auth_password: SecretStr = SecretStr("changeme") diff --git a/memos/default_config.yaml b/memos/default_config.yaml index 4af0058..333dacd 100644 --- a/memos/default_config.yaml +++ b/memos/default_config.yaml @@ -6,7 +6,6 @@ screenshots_dir: screenshots server_host: 0.0.0.0 server_port: 8839 -batchsize: 1 auth_username: admin auth_password: changeme default_plugins: