diff --git a/memos/cmds/library.py b/memos/cmds/library.py index 3b2fcae..f0f52f6 100644 --- a/memos/cmds/library.py +++ b/memos/cmds/library.py @@ -1,35 +1,34 @@ +# Standard library imports +import time import math -import typer -import httpx +import re +import os +import threading import asyncio import logging -import threading - -from tqdm import tqdm +import logging.config from pathlib import Path -from tabulate import tabulate -from memos.config import settings -from magika import Magika from datetime import datetime from enum import Enum from typing import List, Tuple from functools import lru_cache - -import re -import os -import time -import psutil - from collections import defaultdict, deque + +# Third-party imports +import typer +import requests +from tqdm import tqdm +from tabulate import tabulate +import psutil from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler from concurrent.futures import ThreadPoolExecutor -from memos.models import recreate_fts_and_vec_tables +# Local imports +from memos.config import settings from memos.utils import get_image_metadata from memos.schemas import MetadataSource from memos.logging_config import LOGGING_CONFIG -import logging.config logging.config.dictConfig(LOGGING_CONFIG) @@ -37,7 +36,7 @@ logger = logging.getLogger(__name__) lib_app = typer.Typer() -file_detector = Magika() +file_detector = None IS_THUMBNAIL = "is_thumbnail" @@ -57,8 +56,20 @@ def format_timestamp(timestamp): return datetime.fromtimestamp(timestamp).replace(tzinfo=None).isoformat() +def init_file_detector(): + """Initialize the global file detector if not already initialized""" + global file_detector + if file_detector is None: + from magika import Magika + + file_detector = Magika() + return file_detector + + def get_file_type(file_path): - file_result = file_detector.identify_path(file_path) + """Get file type using lazy-loaded detector""" + detector = init_file_detector() + file_result = detector.identify_path(file_path) return file_result.output.ct_label, file_result.output.group @@ -86,7 +97,7 @@ def display_libraries(libraries): @lib_app.command("ls") def ls(): - response = httpx.get(f"{BASE_URL}/libraries") + response = requests.get(f"{BASE_URL}/libraries") libraries = response.json() display_libraries(libraries) @@ -105,11 +116,10 @@ def add(name: str, folders: List[str]): } ) - response = httpx.post( - f"{BASE_URL}/libraries", - json={"name": name, "folders": absolute_folders}, + response = requests.post( + f"{BASE_URL}/libraries", json={"name": name, "folders": absolute_folders} ) - if 200 <= response.status_code < 300: + if response.ok: print("Library created successfully") else: print(f"Failed to create library: {response.status_code} - {response.text}") @@ -129,7 +139,7 @@ def add_folder(library_id: int, folders: List[str]): } ) - response = httpx.post( + response = requests.post( f"{BASE_URL}/libraries/{library_id}/folders", json={"folders": absolute_folders}, ) @@ -143,7 +153,7 @@ def add_folder(library_id: int, folders: List[str]): @lib_app.command("show") def show(library_id: int): - response = httpx.get(f"{BASE_URL}/libraries/{library_id}") + response = requests.get(f"{BASE_URL}/libraries/{library_id}") if response.status_code == 200: library = response.json() display_libraries([library]) @@ -164,7 +174,8 @@ async def loop_files(library_id, folder, folder_path, force, plugins): added_file_count = 0 scanned_files = set() semaphore = asyncio.Semaphore(settings.batchsize) - async with httpx.AsyncClient(timeout=60) as client: + + with requests.Session() as session: tasks = [] for root, _, files in os.walk(folder_path): with tqdm(total=len(files), desc=f"Scanning {root}", leave=True) as pbar: @@ -186,12 +197,12 @@ async def loop_files(library_id, folder, folder_path, force, plugins): batch = candidate_files[i : i + batching] # Get batch of entities - get_response = await client.post( + get_response = session.post( f"{BASE_URL}/libraries/{library_id}/entities/by-filepaths", json=batch, ) - if get_response.status_code == 200: + if get_response.ok: existing_entities = get_response.json() else: print( @@ -305,7 +316,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins): ): tasks.append( update_entity( - client, + session, semaphore, plugins, new_entity, @@ -315,7 +326,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins): elif not is_thumbnail: # Ignore thumbnails tasks.append( add_entity( - client, semaphore, library_id, plugins, new_entity + session, semaphore, library_id, plugins, new_entity ) ) pbar.update(len(batch)) @@ -372,7 +383,7 @@ def scan( print("Error: You cannot specify both a path and folders at the same time.") return - response = httpx.get(f"{BASE_URL}/libraries/{library_id}") + response = requests.get(f"{BASE_URL}/libraries/{library_id}") if response.status_code != 200: print(f"Failed to retrieve library: {response.status_code} - {response.text}") return @@ -428,7 +439,7 @@ def scan( total=total_entities, desc="Checking for deleted files", leave=True ) as pbar2: while True: - existing_files_response = httpx.get( + existing_files_response = requests.get( f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities", params={"limit": limit, "offset": offset}, timeout=60, @@ -459,7 +470,7 @@ def scan( and existing_file["filepath"] not in scanned_files ): # File has been deleted - delete_response = httpx.delete( + delete_response = requests.delete( f"{BASE_URL}/libraries/{library_id}/entities/{existing_file['id']}" ) if 200 <= delete_response.status_code < 300: @@ -481,18 +492,18 @@ def scan( async def add_entity( - client: httpx.AsyncClient, + client: requests.Session, semaphore: asyncio.Semaphore, library_id, plugins, new_entity, -) -> Tuple[FileStatus, bool, httpx.Response]: +) -> Tuple[FileStatus, bool, requests.Response]: async with semaphore: MAX_RETRIES = 3 RETRY_DELAY = 2.0 for attempt in range(MAX_RETRIES): try: - post_response = await client.post( + post_response = client.post( f"{BASE_URL}/libraries/{library_id}/entities", json=new_entity, params={"plugins": plugins} if plugins else {}, @@ -518,18 +529,18 @@ async def add_entity( async def update_entity( - client: httpx.AsyncClient, + client: requests.Session, semaphore: asyncio.Semaphore, plugins, new_entity, existing_entity, -) -> Tuple[FileStatus, bool, httpx.Response]: +) -> Tuple[FileStatus, bool, requests.Response]: MAX_RETRIES = 3 RETRY_DELAY = 2.0 async with semaphore: for attempt in range(MAX_RETRIES): try: - update_response = await client.put( + update_response = client.put( f"{BASE_URL}/entities/{existing_entity['id']}", json=new_entity, params={ @@ -572,8 +583,10 @@ def reindex( ): print(f"Reindexing library {library_id}") + from memos.models import recreate_fts_and_vec_tables + # Get the library - response = httpx.get(f"{BASE_URL}/libraries/{library_id}") + response = requests.get(f"{BASE_URL}/libraries/{library_id}") if response.status_code != 200: print(f"Failed to get library: {response.status_code} - {response.text}") return @@ -594,7 +607,7 @@ def reindex( recreate_fts_and_vec_tables() print("FTS and vector tables have been recreated.") - with httpx.Client(timeout=60) as client: + with requests.Session() as client: total_entities = 0 # Get total entity count for all folders @@ -657,7 +670,7 @@ def reindex( async def check_and_index_entity(client, entity_id, entity_last_scan_at): try: - index_response = await client.get(f"{BASE_URL}/entities/{entity_id}/index") + index_response = client.get(f"{BASE_URL}/entities/{entity_id}/index") if index_response.status_code == 200: index_data = index_response.json() if index_data["last_scan_at"] is None: @@ -668,14 +681,14 @@ async def check_and_index_entity(client, entity_id, entity_last_scan_at): if index_last_scan_at >= entity_last_scan_at: return False # Index is up to date, no need to update return True # Index doesn't exist or needs update - except httpx.HTTPStatusError as e: + except requests.HTTPStatusError as e: if e.response.status_code == 404: return True # Index doesn't exist, need to create raise # Re-raise other HTTP errors async def index_batch(client, entity_ids): - index_response = await client.post( + index_response = client.post( f"{BASE_URL}/entities/batch-index", json=entity_ids, timeout=60, @@ -698,7 +711,7 @@ def sync( Sync a specific file with the library. """ # 1. Get library by id and check if it exists - response = httpx.get(f"{BASE_URL}/libraries/{library_id}") + response = requests.get(f"{BASE_URL}/libraries/{library_id}") if response.status_code != 200: typer.echo(f"Error: Library with id {library_id} not found.") raise typer.Exit(code=1) @@ -713,7 +726,7 @@ def sync( raise typer.Exit(code=1) # 2. Check if the file exists in the library - response = httpx.get( + response = requests.get( f"{BASE_URL}/libraries/{library_id}/entities/by-filepath", params={"filepath": str(file_path)}, ) @@ -782,7 +795,7 @@ def sync( != new_entity["file_last_modified_at"] or existing_entity["size"] != new_entity["size"] ): - update_response = httpx.put( + update_response = requests.put( f"{BASE_URL}/entities/{existing_entity['id']}", json=new_entity, params={"trigger_webhooks_flag": str(not without_webhooks).lower()}, @@ -812,7 +825,7 @@ def sync( # Create new entity new_entity["folder_id"] = folder["id"] - create_response = httpx.post( + create_response = requests.post( f"{BASE_URL}/libraries/{library_id}/entities", json=new_entity, params={"trigger_webhooks_flag": str(not without_webhooks).lower()}, @@ -1052,7 +1065,7 @@ def watch( logger.info(f"Watching library {library_id} for changes...") # Get the library - response = httpx.get(f"{BASE_URL}/libraries/{library_id}") + response = requests.get(f"{BASE_URL}/libraries/{library_id}") if response.status_code != 200: print(f"Error: Library with id {library_id} not found.") raise typer.Exit(code=1) diff --git a/memos/commands.py b/memos/commands.py index 74cdda2..5ca7924 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -1,24 +1,24 @@ +# Standard library imports import os import logging from pathlib import Path from datetime import datetime, timedelta from typing import List -import httpx +# Third-party imports +import requests # 替换 httpx import typer + +# Local imports from .config import settings, display_config -from .models import init_database -from .record import ( - run_screen_recorder_once, - run_screen_recorder, - load_previous_hashes, -) -import time + import sys import subprocess import platform -from .cmds.plugin import plugin_app, bind -from .cmds.library import lib_app, scan, reindex, watch + +from .cmds.plugin import plugin_app +from .cmds.library import lib_app + import psutil import signal from tabulate import tabulate @@ -35,16 +35,16 @@ logging.basicConfig( ) # Optionally, you can set the logging level for specific libraries -logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("requests").setLevel(logging.ERROR) logging.getLogger("typer").setLevel(logging.ERROR) def check_server_health(): """Check if the server is running and healthy.""" try: - response = httpx.get(f"{BASE_URL}/health", timeout=5) + response = requests.get(f"{BASE_URL}/health", timeout=5) return response.status_code == 200 - except httpx.RequestError: + except requests.RequestException: return False @@ -55,13 +55,11 @@ def callback(ctx: typer.Context): "scan", "reindex", "watch", - "ls", "create", "add-folder", "show", "sync", - "bind", "unbind", ] @@ -79,9 +77,12 @@ app.add_typer(lib_app, name="lib", callback=callback) @app.command() def serve(): """Run the server after initializing if necessary.""" + from .models import init_database + db_success = init_database() if db_success: from .server import run_server + run_server() else: print("Server initialization failed. Unable to start the server.") @@ -90,6 +91,8 @@ def serve(): @app.command() def init(): """Initialize the database.""" + from .models import init_database + db_success = init_database() if db_success: print("Initialization completed successfully.") @@ -102,7 +105,8 @@ def get_or_create_default_library(): Get the default library or create it if it doesn't exist. Ensure the library has at least one folder. """ - response = httpx.get(f"{BASE_URL}/libraries") + from .cmds.plugin import bind + response = requests.get(f"{BASE_URL}/libraries") if response.status_code != 200: print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") return None @@ -114,7 +118,7 @@ def get_or_create_default_library(): if not default_library: # Create the default library if it doesn't exist - response = httpx.post( + response = requests.post( f"{BASE_URL}/libraries", json={"name": settings.default_library, "folders": []}, ) @@ -138,7 +142,7 @@ def get_or_create_default_library(): screenshots_dir.stat().st_mtime ).isoformat(), } - response = httpx.post( + response = requests.post( f"{BASE_URL}/libraries/{default_library['id']}/folders", json={"folders": [folder]}, ) @@ -162,13 +166,16 @@ def scan_default_library( """ Scan the screenshots directory and add it to the library if empty. """ + from .cmds.library import scan + default_library = get_or_create_default_library() if not default_library: return - # Scan the library print(f"Scanning library: {default_library['name']}") - scan(default_library["id"], path=path, plugins=plugins, folders=folders, force=force) + scan( + default_library["id"], path=path, plugins=plugins, folders=folders, force=force + ) @app.command("reindex") @@ -180,8 +187,10 @@ def reindex_default_library( """ Reindex the default library for memos. """ + from .cmds.library import reindex + # Get the default library - response = httpx.get(f"{BASE_URL}/libraries") + response = requests.get(f"{BASE_URL}/libraries") if response.status_code != 200: print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") return @@ -209,6 +218,12 @@ def record( """ Record screenshots of the screen. """ + from .record import ( + run_screen_recorder_once, + run_screen_recorder, + load_previous_hashes, + ) + base_dir = ( os.path.expanduser(base_dir) if base_dir else settings.resolved_screenshots_dir ) @@ -242,6 +257,8 @@ def watch_default_library( """ Watch the default library for file changes and sync automatically. """ + from .cmds.library import watch + default_library = get_or_create_default_library() if not default_library: return @@ -343,7 +360,7 @@ def generate_plist(): python_dir = os.path.dirname(get_python_path()) log_dir = memos_dir / "logs" log_dir.mkdir(parents=True, exist_ok=True) - + plist_content = f""" diff --git a/memos/embedding.py b/memos/embedding.py index ac348f9..73b5f19 100644 --- a/memos/embedding.py +++ b/memos/embedding.py @@ -1,12 +1,8 @@ from typing import List -from sentence_transformers import SentenceTransformer -import torch import numpy as np -from modelscope import snapshot_download from .config import settings import logging import httpx -import asyncio # Configure logger logging.basicConfig(level=logging.INFO) @@ -18,6 +14,9 @@ device = None def init_embedding_model(): + import torch + from sentence_transformers import SentenceTransformer + global model, device if torch.cuda.is_available(): device = torch.device("cuda") @@ -27,6 +26,7 @@ def init_embedding_model(): device = torch.device("cpu") if settings.embedding.use_modelscope: + from modelscope import snapshot_download model_dir = snapshot_download(settings.embedding.model) logger.info(f"Model downloaded from ModelScope to: {model_dir}") else: diff --git a/memos/models.py b/memos/models.py index 51782d9..e5b15e5 100644 --- a/memos/models.py +++ b/memos/models.py @@ -475,8 +475,7 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel): thread.start() thread.join() - -# Replace the old event listener with the new sync version +# Add event listeners for EntityModel event.listen(EntityModel, "after_insert", update_fts_and_vec_sync) event.listen(EntityModel, "after_update", update_fts_and_vec_sync) -event.listen(EntityModel, "after_delete", delete_fts_and_vec) +event.listen(EntityModel, "after_delete", delete_fts_and_vec) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 27886ae..3912078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "mss", "sqlite_vec", "watchdog", + "requests", ] [project.urls]