refactor: add lots of dynamic imports to speed up the CLI

This commit is contained in:
arkohut 2024-11-06 13:13:59 +08:00
parent ea8301745f
commit 05c168142a
5 changed files with 108 additions and 78 deletions

View File

@ -1,35 +1,34 @@
# Standard library imports
import time
import math import math
import typer import re
import httpx import os
import threading
import asyncio import asyncio
import logging import logging
import threading import logging.config
from tqdm import tqdm
from pathlib import Path from pathlib import Path
from tabulate import tabulate
from memos.config import settings
from magika import Magika
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import List, Tuple from typing import List, Tuple
from functools import lru_cache from functools import lru_cache
import re
import os
import time
import psutil
from collections import defaultdict, deque from collections import defaultdict, deque
# Third-party imports
import typer
import requests
from tqdm import tqdm
from tabulate import tabulate
import psutil
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler from watchdog.events import FileSystemEventHandler
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from memos.models import recreate_fts_and_vec_tables # Local imports
from memos.config import settings
from memos.utils import get_image_metadata from memos.utils import get_image_metadata
from memos.schemas import MetadataSource from memos.schemas import MetadataSource
from memos.logging_config import LOGGING_CONFIG from memos.logging_config import LOGGING_CONFIG
import logging.config
logging.config.dictConfig(LOGGING_CONFIG) logging.config.dictConfig(LOGGING_CONFIG)
@ -37,7 +36,7 @@ logger = logging.getLogger(__name__)
lib_app = typer.Typer() lib_app = typer.Typer()
file_detector = Magika() file_detector = None
IS_THUMBNAIL = "is_thumbnail" IS_THUMBNAIL = "is_thumbnail"
@ -57,8 +56,20 @@ def format_timestamp(timestamp):
return datetime.fromtimestamp(timestamp).replace(tzinfo=None).isoformat() return datetime.fromtimestamp(timestamp).replace(tzinfo=None).isoformat()
def init_file_detector():
    """Return the shared file detector, creating it on first use.

    The ``magika`` import lives inside the function so that it is only
    executed when a detector is actually requested, not at module import.
    The created instance is stored in the module-level ``file_detector``
    and reused by later calls.
    """
    global file_detector
    if file_detector is not None:
        # Already built by an earlier call; reuse it.
        return file_detector

    from magika import Magika

    file_detector = Magika()
    return file_detector
def get_file_type(file_path): def get_file_type(file_path):
file_result = file_detector.identify_path(file_path) """Get file type using lazy-loaded detector"""
detector = init_file_detector()
file_result = detector.identify_path(file_path)
return file_result.output.ct_label, file_result.output.group return file_result.output.ct_label, file_result.output.group
@ -86,7 +97,7 @@ def display_libraries(libraries):
@lib_app.command("ls") @lib_app.command("ls")
def ls(): def ls():
response = httpx.get(f"{BASE_URL}/libraries") response = requests.get(f"{BASE_URL}/libraries")
libraries = response.json() libraries = response.json()
display_libraries(libraries) display_libraries(libraries)
@ -105,11 +116,10 @@ def add(name: str, folders: List[str]):
} }
) )
response = httpx.post( response = requests.post(
f"{BASE_URL}/libraries", f"{BASE_URL}/libraries", json={"name": name, "folders": absolute_folders}
json={"name": name, "folders": absolute_folders},
) )
if 200 <= response.status_code < 300: if response.ok:
print("Library created successfully") print("Library created successfully")
else: else:
print(f"Failed to create library: {response.status_code} - {response.text}") print(f"Failed to create library: {response.status_code} - {response.text}")
@ -129,7 +139,7 @@ def add_folder(library_id: int, folders: List[str]):
} }
) )
response = httpx.post( response = requests.post(
f"{BASE_URL}/libraries/{library_id}/folders", f"{BASE_URL}/libraries/{library_id}/folders",
json={"folders": absolute_folders}, json={"folders": absolute_folders},
) )
@ -143,7 +153,7 @@ def add_folder(library_id: int, folders: List[str]):
@lib_app.command("show") @lib_app.command("show")
def show(library_id: int): def show(library_id: int):
response = httpx.get(f"{BASE_URL}/libraries/{library_id}") response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code == 200: if response.status_code == 200:
library = response.json() library = response.json()
display_libraries([library]) display_libraries([library])
@ -164,7 +174,8 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
added_file_count = 0 added_file_count = 0
scanned_files = set() scanned_files = set()
semaphore = asyncio.Semaphore(settings.batchsize) semaphore = asyncio.Semaphore(settings.batchsize)
async with httpx.AsyncClient(timeout=60) as client:
with requests.Session() as session:
tasks = [] tasks = []
for root, _, files in os.walk(folder_path): for root, _, files in os.walk(folder_path):
with tqdm(total=len(files), desc=f"Scanning {root}", leave=True) as pbar: with tqdm(total=len(files), desc=f"Scanning {root}", leave=True) as pbar:
@ -186,12 +197,12 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
batch = candidate_files[i : i + batching] batch = candidate_files[i : i + batching]
# Get batch of entities # Get batch of entities
get_response = await client.post( get_response = session.post(
f"{BASE_URL}/libraries/{library_id}/entities/by-filepaths", f"{BASE_URL}/libraries/{library_id}/entities/by-filepaths",
json=batch, json=batch,
) )
if get_response.status_code == 200: if get_response.ok:
existing_entities = get_response.json() existing_entities = get_response.json()
else: else:
print( print(
@ -305,7 +316,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
): ):
tasks.append( tasks.append(
update_entity( update_entity(
client, session,
semaphore, semaphore,
plugins, plugins,
new_entity, new_entity,
@ -315,7 +326,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
elif not is_thumbnail: # Ignore thumbnails elif not is_thumbnail: # Ignore thumbnails
tasks.append( tasks.append(
add_entity( add_entity(
client, semaphore, library_id, plugins, new_entity session, semaphore, library_id, plugins, new_entity
) )
) )
pbar.update(len(batch)) pbar.update(len(batch))
@ -372,7 +383,7 @@ def scan(
print("Error: You cannot specify both a path and folders at the same time.") print("Error: You cannot specify both a path and folders at the same time.")
return return
response = httpx.get(f"{BASE_URL}/libraries/{library_id}") response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to retrieve library: {response.status_code} - {response.text}") print(f"Failed to retrieve library: {response.status_code} - {response.text}")
return return
@ -428,7 +439,7 @@ def scan(
total=total_entities, desc="Checking for deleted files", leave=True total=total_entities, desc="Checking for deleted files", leave=True
) as pbar2: ) as pbar2:
while True: while True:
existing_files_response = httpx.get( existing_files_response = requests.get(
f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities", f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
params={"limit": limit, "offset": offset}, params={"limit": limit, "offset": offset},
timeout=60, timeout=60,
@ -459,7 +470,7 @@ def scan(
and existing_file["filepath"] not in scanned_files and existing_file["filepath"] not in scanned_files
): ):
# File has been deleted # File has been deleted
delete_response = httpx.delete( delete_response = requests.delete(
f"{BASE_URL}/libraries/{library_id}/entities/{existing_file['id']}" f"{BASE_URL}/libraries/{library_id}/entities/{existing_file['id']}"
) )
if 200 <= delete_response.status_code < 300: if 200 <= delete_response.status_code < 300:
@ -481,18 +492,18 @@ def scan(
async def add_entity( async def add_entity(
client: httpx.AsyncClient, client: requests.Session,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
library_id, library_id,
plugins, plugins,
new_entity, new_entity,
) -> Tuple[FileStatus, bool, httpx.Response]: ) -> Tuple[FileStatus, bool, requests.Response]:
async with semaphore: async with semaphore:
MAX_RETRIES = 3 MAX_RETRIES = 3
RETRY_DELAY = 2.0 RETRY_DELAY = 2.0
for attempt in range(MAX_RETRIES): for attempt in range(MAX_RETRIES):
try: try:
post_response = await client.post( post_response = client.post(
f"{BASE_URL}/libraries/{library_id}/entities", f"{BASE_URL}/libraries/{library_id}/entities",
json=new_entity, json=new_entity,
params={"plugins": plugins} if plugins else {}, params={"plugins": plugins} if plugins else {},
@ -518,18 +529,18 @@ async def add_entity(
async def update_entity( async def update_entity(
client: httpx.AsyncClient, client: requests.Session,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
plugins, plugins,
new_entity, new_entity,
existing_entity, existing_entity,
) -> Tuple[FileStatus, bool, httpx.Response]: ) -> Tuple[FileStatus, bool, requests.Response]:
MAX_RETRIES = 3 MAX_RETRIES = 3
RETRY_DELAY = 2.0 RETRY_DELAY = 2.0
async with semaphore: async with semaphore:
for attempt in range(MAX_RETRIES): for attempt in range(MAX_RETRIES):
try: try:
update_response = await client.put( update_response = client.put(
f"{BASE_URL}/entities/{existing_entity['id']}", f"{BASE_URL}/entities/{existing_entity['id']}",
json=new_entity, json=new_entity,
params={ params={
@ -572,8 +583,10 @@ def reindex(
): ):
print(f"Reindexing library {library_id}") print(f"Reindexing library {library_id}")
from memos.models import recreate_fts_and_vec_tables
# Get the library # Get the library
response = httpx.get(f"{BASE_URL}/libraries/{library_id}") response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to get library: {response.status_code} - {response.text}") print(f"Failed to get library: {response.status_code} - {response.text}")
return return
@ -594,7 +607,7 @@ def reindex(
recreate_fts_and_vec_tables() recreate_fts_and_vec_tables()
print("FTS and vector tables have been recreated.") print("FTS and vector tables have been recreated.")
with httpx.Client(timeout=60) as client: with requests.Session() as client:
total_entities = 0 total_entities = 0
# Get total entity count for all folders # Get total entity count for all folders
@ -657,7 +670,7 @@ def reindex(
async def check_and_index_entity(client, entity_id, entity_last_scan_at): async def check_and_index_entity(client, entity_id, entity_last_scan_at):
try: try:
index_response = await client.get(f"{BASE_URL}/entities/{entity_id}/index") index_response = client.get(f"{BASE_URL}/entities/{entity_id}/index")
if index_response.status_code == 200: if index_response.status_code == 200:
index_data = index_response.json() index_data = index_response.json()
if index_data["last_scan_at"] is None: if index_data["last_scan_at"] is None:
@ -668,14 +681,14 @@ async def check_and_index_entity(client, entity_id, entity_last_scan_at):
if index_last_scan_at >= entity_last_scan_at: if index_last_scan_at >= entity_last_scan_at:
return False # Index is up to date, no need to update return False # Index is up to date, no need to update
return True # Index doesn't exist or needs update return True # Index doesn't exist or needs update
except httpx.HTTPStatusError as e: except requests.HTTPStatusError as e:
if e.response.status_code == 404: if e.response.status_code == 404:
return True # Index doesn't exist, need to create return True # Index doesn't exist, need to create
raise # Re-raise other HTTP errors raise # Re-raise other HTTP errors
async def index_batch(client, entity_ids): async def index_batch(client, entity_ids):
index_response = await client.post( index_response = client.post(
f"{BASE_URL}/entities/batch-index", f"{BASE_URL}/entities/batch-index",
json=entity_ids, json=entity_ids,
timeout=60, timeout=60,
@ -698,7 +711,7 @@ def sync(
Sync a specific file with the library. Sync a specific file with the library.
""" """
# 1. Get library by id and check if it exists # 1. Get library by id and check if it exists
response = httpx.get(f"{BASE_URL}/libraries/{library_id}") response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
typer.echo(f"Error: Library with id {library_id} not found.") typer.echo(f"Error: Library with id {library_id} not found.")
raise typer.Exit(code=1) raise typer.Exit(code=1)
@ -713,7 +726,7 @@ def sync(
raise typer.Exit(code=1) raise typer.Exit(code=1)
# 2. Check if the file exists in the library # 2. Check if the file exists in the library
response = httpx.get( response = requests.get(
f"{BASE_URL}/libraries/{library_id}/entities/by-filepath", f"{BASE_URL}/libraries/{library_id}/entities/by-filepath",
params={"filepath": str(file_path)}, params={"filepath": str(file_path)},
) )
@ -782,7 +795,7 @@ def sync(
!= new_entity["file_last_modified_at"] != new_entity["file_last_modified_at"]
or existing_entity["size"] != new_entity["size"] or existing_entity["size"] != new_entity["size"]
): ):
update_response = httpx.put( update_response = requests.put(
f"{BASE_URL}/entities/{existing_entity['id']}", f"{BASE_URL}/entities/{existing_entity['id']}",
json=new_entity, json=new_entity,
params={"trigger_webhooks_flag": str(not without_webhooks).lower()}, params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@ -812,7 +825,7 @@ def sync(
# Create new entity # Create new entity
new_entity["folder_id"] = folder["id"] new_entity["folder_id"] = folder["id"]
create_response = httpx.post( create_response = requests.post(
f"{BASE_URL}/libraries/{library_id}/entities", f"{BASE_URL}/libraries/{library_id}/entities",
json=new_entity, json=new_entity,
params={"trigger_webhooks_flag": str(not without_webhooks).lower()}, params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@ -1052,7 +1065,7 @@ def watch(
logger.info(f"Watching library {library_id} for changes...") logger.info(f"Watching library {library_id} for changes...")
# Get the library # Get the library
response = httpx.get(f"{BASE_URL}/libraries/{library_id}") response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
print(f"Error: Library with id {library_id} not found.") print(f"Error: Library with id {library_id} not found.")
raise typer.Exit(code=1) raise typer.Exit(code=1)

View File

@ -1,24 +1,24 @@
# Standard library imports
import os import os
import logging import logging
from pathlib import Path from pathlib import Path
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List from typing import List
import httpx # Third-party imports
import requests # replaces httpx
import typer import typer
# Local imports
from .config import settings, display_config from .config import settings, display_config
from .models import init_database
from .record import (
run_screen_recorder_once,
run_screen_recorder,
load_previous_hashes,
)
import time
import sys import sys
import subprocess import subprocess
import platform import platform
from .cmds.plugin import plugin_app, bind
from .cmds.library import lib_app, scan, reindex, watch from .cmds.plugin import plugin_app
from .cmds.library import lib_app
import psutil import psutil
import signal import signal
from tabulate import tabulate from tabulate import tabulate
@ -35,16 +35,16 @@ logging.basicConfig(
) )
# Optionally, you can set the logging level for specific libraries # Optionally, you can set the logging level for specific libraries
logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("requests").setLevel(logging.ERROR)
logging.getLogger("typer").setLevel(logging.ERROR) logging.getLogger("typer").setLevel(logging.ERROR)
def check_server_health(): def check_server_health():
"""Check if the server is running and healthy.""" """Check if the server is running and healthy."""
try: try:
response = httpx.get(f"{BASE_URL}/health", timeout=5) response = requests.get(f"{BASE_URL}/health", timeout=5)
return response.status_code == 200 return response.status_code == 200
except httpx.RequestError: except requests.RequestException:
return False return False
@ -55,13 +55,11 @@ def callback(ctx: typer.Context):
"scan", "scan",
"reindex", "reindex",
"watch", "watch",
"ls", "ls",
"create", "create",
"add-folder", "add-folder",
"show", "show",
"sync", "sync",
"bind", "bind",
"unbind", "unbind",
] ]
@ -79,9 +77,12 @@ app.add_typer(lib_app, name="lib", callback=callback)
@app.command() @app.command()
def serve(): def serve():
"""Run the server after initializing if necessary.""" """Run the server after initializing if necessary."""
from .models import init_database
db_success = init_database() db_success = init_database()
if db_success: if db_success:
from .server import run_server from .server import run_server
run_server() run_server()
else: else:
print("Server initialization failed. Unable to start the server.") print("Server initialization failed. Unable to start the server.")
@ -90,6 +91,8 @@ def serve():
@app.command() @app.command()
def init(): def init():
"""Initialize the database.""" """Initialize the database."""
from .models import init_database
db_success = init_database() db_success = init_database()
if db_success: if db_success:
print("Initialization completed successfully.") print("Initialization completed successfully.")
@ -102,7 +105,8 @@ def get_or_create_default_library():
Get the default library or create it if it doesn't exist. Get the default library or create it if it doesn't exist.
Ensure the library has at least one folder. Ensure the library has at least one folder.
""" """
response = httpx.get(f"{BASE_URL}/libraries") from .cmds.plugin import bind
response = requests.get(f"{BASE_URL}/libraries")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return None return None
@ -114,7 +118,7 @@ def get_or_create_default_library():
if not default_library: if not default_library:
# Create the default library if it doesn't exist # Create the default library if it doesn't exist
response = httpx.post( response = requests.post(
f"{BASE_URL}/libraries", f"{BASE_URL}/libraries",
json={"name": settings.default_library, "folders": []}, json={"name": settings.default_library, "folders": []},
) )
@ -138,7 +142,7 @@ def get_or_create_default_library():
screenshots_dir.stat().st_mtime screenshots_dir.stat().st_mtime
).isoformat(), ).isoformat(),
} }
response = httpx.post( response = requests.post(
f"{BASE_URL}/libraries/{default_library['id']}/folders", f"{BASE_URL}/libraries/{default_library['id']}/folders",
json={"folders": [folder]}, json={"folders": [folder]},
) )
@ -162,13 +166,16 @@ def scan_default_library(
""" """
Scan the screenshots directory and add it to the library if empty. Scan the screenshots directory and add it to the library if empty.
""" """
from .cmds.library import scan
default_library = get_or_create_default_library() default_library = get_or_create_default_library()
if not default_library: if not default_library:
return return
# Scan the library
print(f"Scanning library: {default_library['name']}") print(f"Scanning library: {default_library['name']}")
scan(default_library["id"], path=path, plugins=plugins, folders=folders, force=force) scan(
default_library["id"], path=path, plugins=plugins, folders=folders, force=force
)
@app.command("reindex") @app.command("reindex")
@ -180,8 +187,10 @@ def reindex_default_library(
""" """
Reindex the default library for memos. Reindex the default library for memos.
""" """
from .cmds.library import reindex
# Get the default library # Get the default library
response = httpx.get(f"{BASE_URL}/libraries") response = requests.get(f"{BASE_URL}/libraries")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return return
@ -209,6 +218,12 @@ def record(
""" """
Record screenshots of the screen. Record screenshots of the screen.
""" """
from .record import (
run_screen_recorder_once,
run_screen_recorder,
load_previous_hashes,
)
base_dir = ( base_dir = (
os.path.expanduser(base_dir) if base_dir else settings.resolved_screenshots_dir os.path.expanduser(base_dir) if base_dir else settings.resolved_screenshots_dir
) )
@ -242,6 +257,8 @@ def watch_default_library(
""" """
Watch the default library for file changes and sync automatically. Watch the default library for file changes and sync automatically.
""" """
from .cmds.library import watch
default_library = get_or_create_default_library() default_library = get_or_create_default_library()
if not default_library: if not default_library:
return return
@ -343,7 +360,7 @@ def generate_plist():
python_dir = os.path.dirname(get_python_path()) python_dir = os.path.dirname(get_python_path())
log_dir = memos_dir / "logs" log_dir = memos_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
plist_content = f"""<?xml version="1.0" encoding="UTF-8"?> plist_content = f"""<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd"> "http://www.apple.com/DTDs/PropertyList-1.0.dtd">

View File

@ -1,12 +1,8 @@
from typing import List from typing import List
from sentence_transformers import SentenceTransformer
import torch
import numpy as np import numpy as np
from modelscope import snapshot_download
from .config import settings from .config import settings
import logging import logging
import httpx import httpx
import asyncio
# Configure logger # Configure logger
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -18,6 +14,9 @@ device = None
def init_embedding_model(): def init_embedding_model():
import torch
from sentence_transformers import SentenceTransformer
global model, device global model, device
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
@ -27,6 +26,7 @@ def init_embedding_model():
device = torch.device("cpu") device = torch.device("cpu")
if settings.embedding.use_modelscope: if settings.embedding.use_modelscope:
from modelscope import snapshot_download
model_dir = snapshot_download(settings.embedding.model) model_dir = snapshot_download(settings.embedding.model)
logger.info(f"Model downloaded from ModelScope to: {model_dir}") logger.info(f"Model downloaded from ModelScope to: {model_dir}")
else: else:

View File

@ -475,8 +475,7 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
thread.start() thread.start()
thread.join() thread.join()
# Add event listeners for EntityModel
# Replace the old event listener with the new sync version
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync) event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
event.listen(EntityModel, "after_update", update_fts_and_vec_sync) event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
event.listen(EntityModel, "after_delete", delete_fts_and_vec) event.listen(EntityModel, "after_delete", delete_fts_and_vec)

View File

@ -47,6 +47,7 @@ dependencies = [
"mss", "mss",
"sqlite_vec", "sqlite_vec",
"watchdog", "watchdog",
"requests",
] ]
[project.urls] [project.urls]