refactor: add dynamic imports to speed up CLI startup

arkohut 2024-11-06 13:13:59 +08:00
parent ea8301745f
commit 05c168142a
5 changed files with 108 additions and 78 deletions
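The idea behind the change: module-level imports of heavy packages (magika, torch, sentence_transformers, modelscope) are paid on every CLI invocation, even for a bare `--help` run, so the commit moves them inside the functions that need them and caches the constructed objects in module globals. A minimal sketch of the pattern (illustrative; it mirrors the init_file_detector change in the first file below):

# module import stays cheap; the heavy dependency is loaded on first use
_detector = None

def get_detector():
    global _detector
    if _detector is None:
        from magika import Magika  # deferred import: only paid when needed
        _detector = Magika()
    return _detector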

View File

@@ -1,35 +1,34 @@
# Standard library imports
import time
import math
import typer
import httpx
import re
import os
import threading
import asyncio
import logging
import threading
from tqdm import tqdm
import logging.config
from pathlib import Path
from tabulate import tabulate
from memos.config import settings
from magika import Magika
from datetime import datetime
from enum import Enum
from typing import List, Tuple
from functools import lru_cache
import re
import os
import time
import psutil
from collections import defaultdict, deque
# Third-party imports
import typer
import requests
from tqdm import tqdm
from tabulate import tabulate
import psutil
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from concurrent.futures import ThreadPoolExecutor
from memos.models import recreate_fts_and_vec_tables
# Local imports
from memos.config import settings
from memos.utils import get_image_metadata
from memos.schemas import MetadataSource
from memos.logging_config import LOGGING_CONFIG
import logging.config
logging.config.dictConfig(LOGGING_CONFIG)
@@ -37,7 +36,7 @@ logger = logging.getLogger(__name__)
lib_app = typer.Typer()
file_detector = Magika()
file_detector = None
IS_THUMBNAIL = "is_thumbnail"
@@ -57,8 +56,20 @@ def format_timestamp(timestamp):
return datetime.fromtimestamp(timestamp).replace(tzinfo=None).isoformat()
def init_file_detector():
"""Initialize the global file detector if not already initialized"""
global file_detector
if file_detector is None:
from magika import Magika
file_detector = Magika()
return file_detector
def get_file_type(file_path):
file_result = file_detector.identify_path(file_path)
"""Get file type using lazy-loaded detector"""
detector = init_file_detector()
file_result = detector.identify_path(file_path)
return file_result.output.ct_label, file_result.output.group
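The cost being deferred is easy to measure: CPython's `python -X importtime` prints a per-module breakdown of import time, or it can be timed directly (a sketch; numbers vary by environment):

import time

t0 = time.perf_counter()
import magika  # cost of the import alone; constructing Magika() adds model loading
print(f"import magika: {time.perf_counter() - t0:.2f}s")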
@@ -86,7 +97,7 @@ def display_libraries(libraries):
@lib_app.command("ls")
def ls():
response = httpx.get(f"{BASE_URL}/libraries")
response = requests.get(f"{BASE_URL}/libraries")
libraries = response.json()
display_libraries(libraries)
@@ -105,11 +116,10 @@ def add(name: str, folders: List[str]):
}
)
response = httpx.post(
f"{BASE_URL}/libraries",
json={"name": name, "folders": absolute_folders},
response = requests.post(
f"{BASE_URL}/libraries", json={"name": name, "folders": absolute_folders}
)
if 200 <= response.status_code < 300:
if response.ok:
print("Library created successfully")
else:
print(f"Failed to create library: {response.status_code} - {response.text}")
@@ -129,7 +139,7 @@ def add_folder(library_id: int, folders: List[str]):
}
)
response = httpx.post(
response = requests.post(
f"{BASE_URL}/libraries/{library_id}/folders",
json={"folders": absolute_folders},
)
@@ -143,7 +153,7 @@ def add_folder(library_id: int, folders: List[str]):
@lib_app.command("show")
def show(library_id: int):
response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code == 200:
library = response.json()
display_libraries([library])
@@ -164,7 +174,8 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
added_file_count = 0
scanned_files = set()
semaphore = asyncio.Semaphore(settings.batchsize)
async with httpx.AsyncClient(timeout=60) as client:
with requests.Session() as session:
tasks = []
for root, _, files in os.walk(folder_path):
with tqdm(total=len(files), desc=f"Scanning {root}", leave=True) as pbar:
@@ -186,12 +197,12 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
batch = candidate_files[i : i + batching]
# Get batch of entities
get_response = await client.post(
get_response = session.post(
f"{BASE_URL}/libraries/{library_id}/entities/by-filepaths",
json=batch,
)
if get_response.status_code == 200:
if get_response.ok:
existing_entities = get_response.json()
else:
print(
@@ -305,7 +316,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
):
tasks.append(
update_entity(
client,
session,
semaphore,
plugins,
new_entity,
@@ -315,7 +326,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
elif not is_thumbnail: # Ignore thumbnails
tasks.append(
add_entity(
client, semaphore, library_id, plugins, new_entity
session, semaphore, library_id, plugins, new_entity
)
)
pbar.update(len(batch))
@@ -372,7 +383,7 @@ def scan(
print("Error: You cannot specify both a path and folders at the same time.")
return
response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200:
print(f"Failed to retrieve library: {response.status_code} - {response.text}")
return
@@ -428,7 +439,7 @@ def scan(
total=total_entities, desc="Checking for deleted files", leave=True
) as pbar2:
while True:
existing_files_response = httpx.get(
existing_files_response = requests.get(
f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
params={"limit": limit, "offset": offset},
timeout=60,
@@ -459,7 +470,7 @@ def scan(
and existing_file["filepath"] not in scanned_files
):
# File has been deleted
delete_response = httpx.delete(
delete_response = requests.delete(
f"{BASE_URL}/libraries/{library_id}/entities/{existing_file['id']}"
)
if 200 <= delete_response.status_code < 300:
@@ -481,18 +492,18 @@ def scan(
async def add_entity(
client: httpx.AsyncClient,
client: requests.Session,
semaphore: asyncio.Semaphore,
library_id,
plugins,
new_entity,
) -> Tuple[FileStatus, bool, httpx.Response]:
) -> Tuple[FileStatus, bool, requests.Response]:
async with semaphore:
MAX_RETRIES = 3
RETRY_DELAY = 2.0
for attempt in range(MAX_RETRIES):
try:
post_response = await client.post(
post_response = client.post(
f"{BASE_URL}/libraries/{library_id}/entities",
json=new_entity,
params={"plugins": plugins} if plugins else {},
@@ -518,18 +529,18 @@ async def add_entity(
async def update_entity(
client: httpx.AsyncClient,
client: requests.Session,
semaphore: asyncio.Semaphore,
plugins,
new_entity,
existing_entity,
) -> Tuple[FileStatus, bool, httpx.Response]:
) -> Tuple[FileStatus, bool, requests.Response]:
MAX_RETRIES = 3
RETRY_DELAY = 2.0
async with semaphore:
for attempt in range(MAX_RETRIES):
try:
update_response = await client.put(
update_response = client.put(
f"{BASE_URL}/entities/{existing_entity['id']}",
json=new_entity,
params={
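One trade-off in the httpx-to-requests swap: requests is synchronous, so although add_entity and update_entity remain async def and are still gated by an asyncio.Semaphore, each session.post/session.put now blocks the event loop for its full duration, and the gathered tasks effectively run one at a time. If real concurrency were wanted again without reverting to httpx, one option (a sketch, not part of this commit) is to push the blocking call onto a worker thread:

import asyncio
import requests

async def post_json(session: requests.Session, url: str, payload: dict) -> requests.Response:
    # asyncio.to_thread runs the blocking call in the default thread pool,
    # letting other tasks make progress meanwhile
    return await asyncio.to_thread(session.post, url, json=payload, timeout=60)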
@@ -572,8 +583,10 @@ def reindex(
):
print(f"Reindexing library {library_id}")
from memos.models import recreate_fts_and_vec_tables
# Get the library
response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200:
print(f"Failed to get library: {response.status_code} - {response.text}")
return
@@ -594,7 +607,7 @@ def reindex(
recreate_fts_and_vec_tables()
print("FTS and vector tables have been recreated.")
with httpx.Client(timeout=60) as client:
with requests.Session() as client:
total_entities = 0
# Get total entity count for all folders
@@ -657,7 +670,7 @@ def reindex(
async def check_and_index_entity(client, entity_id, entity_last_scan_at):
try:
index_response = await client.get(f"{BASE_URL}/entities/{entity_id}/index")
index_response = client.get(f"{BASE_URL}/entities/{entity_id}/index")
if index_response.status_code == 200:
index_data = index_response.json()
if index_data["last_scan_at"] is None:
@@ -668,14 +681,14 @@ async def check_and_index_entity(client, entity_id, entity_last_scan_at):
if index_last_scan_at >= entity_last_scan_at:
return False # Index is up to date, no need to update
return True # Index doesn't exist or needs update
except httpx.HTTPStatusError as e:
except requests.HTTPError as e:
if e.response.status_code == 404:
return True # Index doesn't exist, need to create
raise # Re-raise other HTTP errors
async def index_batch(client, entity_ids):
index_response = await client.post(
index_response = client.post(
f"{BASE_URL}/entities/batch-index",
json=entity_ids,
timeout=60,
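A note on the except clause above: like httpx's HTTPStatusError, requests' HTTPError is only raised when raise_for_status() is called explicitly; a plain client.get() returns normally on a 404. The 404 branch therefore assumes a call shape like the following (sketch):

resp = client.get(url)
resp.raise_for_status()  # 4xx/5xx raise requests.HTTPError with e.response set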
@@ -698,7 +711,7 @@ def sync(
Sync a specific file with the library.
"""
# 1. Get library by id and check if it exists
response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200:
typer.echo(f"Error: Library with id {library_id} not found.")
raise typer.Exit(code=1)
@@ -713,7 +726,7 @@ def sync(
raise typer.Exit(code=1)
# 2. Check if the file exists in the library
response = httpx.get(
response = requests.get(
f"{BASE_URL}/libraries/{library_id}/entities/by-filepath",
params={"filepath": str(file_path)},
)
@@ -782,7 +795,7 @@ def sync(
!= new_entity["file_last_modified_at"]
or existing_entity["size"] != new_entity["size"]
):
update_response = httpx.put(
update_response = requests.put(
f"{BASE_URL}/entities/{existing_entity['id']}",
json=new_entity,
params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@@ -812,7 +825,7 @@ def sync(
# Create new entity
new_entity["folder_id"] = folder["id"]
create_response = httpx.post(
create_response = requests.post(
f"{BASE_URL}/libraries/{library_id}/entities",
json=new_entity,
params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@@ -1052,7 +1065,7 @@ def watch(
logger.info(f"Watching library {library_id} for changes...")
# Get the library
response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
response = requests.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200:
print(f"Error: Library with id {library_id} not found.")
raise typer.Exit(code=1)

View File

@@ -1,24 +1,24 @@
# Standard library imports
import os
import logging
from pathlib import Path
from datetime import datetime, timedelta
from typing import List
import httpx
# Third-party imports
import requests  # replaces httpx
import typer
# Local imports
from .config import settings, display_config
from .models import init_database
from .record import (
run_screen_recorder_once,
run_screen_recorder,
load_previous_hashes,
)
import time
import sys
import subprocess
import platform
from .cmds.plugin import plugin_app, bind
from .cmds.library import lib_app, scan, reindex, watch
from .cmds.plugin import plugin_app
from .cmds.library import lib_app
import psutil
import signal
from tabulate import tabulate
@@ -35,16 +35,16 @@ logging.basicConfig(
)
# Optionally, you can set the logging level for specific libraries
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("requests").setLevel(logging.ERROR)
logging.getLogger("typer").setLevel(logging.ERROR)
def check_server_health():
"""Check if the server is running and healthy."""
try:
response = httpx.get(f"{BASE_URL}/health", timeout=5)
response = requests.get(f"{BASE_URL}/health", timeout=5)
return response.status_code == 200
except httpx.RequestError:
except requests.RequestException:
return False
@@ -55,13 +55,11 @@ def callback(ctx: typer.Context):
"scan",
"reindex",
"watch",
"ls",
"create",
"add-folder",
"show",
"sync",
"bind",
"unbind",
]
@@ -79,9 +77,12 @@ app.add_typer(lib_app, name="lib", callback=callback)
@app.command()
def serve():
"""Run the server after initializing if necessary."""
from .models import init_database
db_success = init_database()
if db_success:
from .server import run_server
run_server()
else:
print("Server initialization failed. Unable to start the server.")
@@ -90,6 +91,8 @@ def serve():
@app.command()
def init():
"""Initialize the database."""
from .models import init_database
db_success = init_database()
if db_success:
print("Initialization completed successfully.")
@@ -102,7 +105,8 @@ def get_or_create_default_library():
Get the default library or create it if it doesn't exist.
Ensure the library has at least one folder.
"""
response = httpx.get(f"{BASE_URL}/libraries")
from .cmds.plugin import bind
response = requests.get(f"{BASE_URL}/libraries")
if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return None
@@ -114,7 +118,7 @@ def get_or_create_default_library():
if not default_library:
# Create the default library if it doesn't exist
response = httpx.post(
response = requests.post(
f"{BASE_URL}/libraries",
json={"name": settings.default_library, "folders": []},
)
@@ -138,7 +142,7 @@ def get_or_create_default_library():
screenshots_dir.stat().st_mtime
).isoformat(),
}
response = httpx.post(
response = requests.post(
f"{BASE_URL}/libraries/{default_library['id']}/folders",
json={"folders": [folder]},
)
@@ -162,13 +166,16 @@ def scan_default_library(
"""
Scan the screenshots directory and add it to the library if empty.
"""
from .cmds.library import scan
default_library = get_or_create_default_library()
if not default_library:
return
# Scan the library
print(f"Scanning library: {default_library['name']}")
scan(default_library["id"], path=path, plugins=plugins, folders=folders, force=force)
scan(
default_library["id"], path=path, plugins=plugins, folders=folders, force=force
)
@app.command("reindex")
@@ -180,8 +187,10 @@ def reindex_default_library(
"""
Reindex the default library for memos.
"""
from .cmds.library import reindex
# Get the default library
response = httpx.get(f"{BASE_URL}/libraries")
response = requests.get(f"{BASE_URL}/libraries")
if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return
@@ -209,6 +218,12 @@ def record(
"""
Record screenshots of the screen.
"""
from .record import (
run_screen_recorder_once,
run_screen_recorder,
load_previous_hashes,
)
base_dir = (
os.path.expanduser(base_dir) if base_dir else settings.resolved_screenshots_dir
)
@@ -242,6 +257,8 @@ def watch_default_library(
"""
Watch the default library for file changes and sync automatically.
"""
from .cmds.library import watch
default_library = get_or_create_default_library()
if not default_library:
return
@@ -343,7 +360,7 @@ def generate_plist():
python_dir = os.path.dirname(get_python_path())
log_dir = memos_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
plist_content = f"""<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">

View File

@@ -1,12 +1,8 @@
from typing import List
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from modelscope import snapshot_download
from .config import settings
import logging
import httpx
import asyncio
# Configure logger
logging.basicConfig(level=logging.INFO)
@@ -18,6 +14,9 @@ device = None
def init_embedding_model():
import torch
from sentence_transformers import SentenceTransformer
global model, device
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -27,6 +26,7 @@ def init_embedding_model():
device = torch.device("cpu")
if settings.embedding.use_modelscope:
from modelscope import snapshot_download
model_dir = snapshot_download(settings.embedding.model)
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
else:
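Here the heaviest imports in the whole tree (torch, sentence_transformers, and optionally modelscope) are deferred until init_embedding_model() actually runs. The device-selection logic is partly elided by the hunk; a common shape for it, with the mps branch an assumption rather than something visible in the diff:

def pick_device():
    import torch  # deferred, matching the style of init_embedding_model
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():  # assumed branch; elided in the hunk
        return torch.device("mps")
    return torch.device("cpu")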

View File

@@ -475,8 +475,7 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
thread.start()
thread.join()
# Replace the old event listener with the new sync version
# Add event listeners for EntityModel
event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)

View File

@@ -47,6 +47,7 @@ dependencies = [
"mss",
"sqlite_vec",
"watchdog",
"requests",
]
[project.urls]