refactor: add batch size option for scan command

This commit is contained in:
arkohut 2024-11-06 16:41:10 +08:00
parent b387807ee4
commit f524622728
4 changed files with 18 additions and 11 deletions

View File

@ -169,11 +169,11 @@ def is_temp_file(filename):
) )
async def loop_files(library_id, folder, folder_path, force, plugins): async def loop_files(library_id, folder, folder_path, force, plugins, batch_size):
updated_file_count = 0 updated_file_count = 0
added_file_count = 0 added_file_count = 0
scanned_files = set() scanned_files = set()
semaphore = asyncio.Semaphore(settings.batchsize) semaphore = asyncio.Semaphore(batch_size)
async with httpx.AsyncClient(timeout=60) as client: async with httpx.AsyncClient(timeout=60) as client:
tasks = [] tasks = []
@ -377,6 +377,7 @@ def scan(
force: bool = False, force: bool = False,
plugins: List[int] = typer.Option(None, "--plugin", "-p"), plugins: List[int] = typer.Option(None, "--plugin", "-p"),
folders: List[int] = typer.Option(None, "--folder", "-f"), folders: List[int] = typer.Option(None, "--folder", "-f"),
batch_size: int = typer.Option(1, "--batch-size", "-bs", help="Batch size for processing files"),
): ):
# Check if both path and folders are provided # Check if both path and folders are provided
if path and folders: if path and folders:
@ -426,7 +427,7 @@ def scan(
continue continue
added_file_count, updated_file_count, scanned_files = asyncio.run( added_file_count, updated_file_count, scanned_files = asyncio.run(
loop_files(library_id, folder, folder_path, force, plugins) loop_files(library_id, folder, folder_path, force, plugins, batch_size)
) )
total_files_added += added_file_count total_files_added += added_file_count
total_files_updated += updated_file_count total_files_updated += updated_file_count

View File

@ -92,7 +92,7 @@ def serve():
def init(): def init():
"""Initialize the database.""" """Initialize the database."""
from .models import init_database from .models import init_database
db_success = init_database() db_success = init_database()
if db_success: if db_success:
print("Initialization completed successfully.") print("Initialization completed successfully.")
@ -106,6 +106,7 @@ def get_or_create_default_library():
Ensure the library has at least one folder. Ensure the library has at least one folder.
""" """
from .cmds.plugin import bind from .cmds.plugin import bind
response = httpx.get(f"{BASE_URL}/libraries") response = httpx.get(f"{BASE_URL}/libraries")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
@ -162,19 +163,27 @@ def scan_default_library(
path: str = typer.Argument(None, help="Path to scan within the library"), path: str = typer.Argument(None, help="Path to scan within the library"),
plugins: List[int] = typer.Option(None, "--plugin", "-p"), plugins: List[int] = typer.Option(None, "--plugin", "-p"),
folders: List[int] = typer.Option(None, "--folder", "-f"), folders: List[int] = typer.Option(None, "--folder", "-f"),
batch_size: int = typer.Option(
1, "--batch-size", "-bs", help="Batch size for processing files"
),
): ):
""" """
Scan the screenshots directory and add it to the library if empty. Scan the screenshots directory and add it to the library if empty.
""" """
from .cmds.library import scan from .cmds.library import scan
default_library = get_or_create_default_library() default_library = get_or_create_default_library()
if not default_library: if not default_library:
return return
print(f"Scanning library: {default_library['name']}") print(f"Scanning library: {default_library['name']}")
scan( scan(
default_library["id"], path=path, plugins=plugins, folders=folders, force=force default_library["id"],
path=path,
plugins=plugins,
folders=folders,
force=force,
batch_size=batch_size,
) )
@ -188,7 +197,7 @@ def reindex_default_library(
Reindex the default library for memos. Reindex the default library for memos.
""" """
from .cmds.library import reindex from .cmds.library import reindex
# Get the default library # Get the default library
response = httpx.get(f"{BASE_URL}/libraries") response = httpx.get(f"{BASE_URL}/libraries")
if response.status_code != 200: if response.status_code != 200:
@ -258,7 +267,7 @@ def watch_default_library(
Watch the default library for file changes and sync automatically. Watch the default library for file changes and sync automatically.
""" """
from .cmds.library import watch from .cmds.library import watch
default_library = get_or_create_default_library() default_library = get_or_create_default_library()
if not default_library: if not default_library:
return return

View File

@ -71,8 +71,6 @@ class Settings(BaseSettings):
# Embedding settings # Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings() embedding: EmbeddingSettings = EmbeddingSettings()
batchsize: int = 1
auth_username: str = "admin" auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme") auth_password: SecretStr = SecretStr("changeme")

View File

@ -6,7 +6,6 @@ screenshots_dir: screenshots
server_host: 0.0.0.0 server_host: 0.0.0.0
server_port: 8839 server_port: 8839
batchsize: 1
auth_username: admin auth_username: admin
auth_password: changeme auth_password: changeme
default_plugins: default_plugins: