mirror of https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 19:25:24 +00:00

refact: add lots of dynamic import to speed up cli

This commit is contained in:
parent ea8301745f
commit 05c168142a
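The hunks below all apply one idiom: imports of heavy dependencies (magika, torch, sentence_transformers, modelscope, the record and model modules) move from module scope into the functions that first need them, so fast CLI paths such as `--help` no longer pay their import cost at startup. A minimal sketch of the idiom, using illustrative names rather than code from this commit:

    _detector = None

    def get_detector():
        """Import and construct the expensive object on first use only."""
        global _detector
        if _detector is None:
            from magika import Magika  # deferred: not paid at CLI startup
            _detector = Magika()
        return _detector

The commit also swaps the httpx client for requests throughout the CLI, which removes another import and the async client setup from the startup path.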
@@ -1,35 +1,34 @@
+# Standard library imports
+import time
 import math
-import typer
-import httpx
+import re
+import os
+import threading
 import asyncio
 import logging
-import threading
+import logging.config
 
-from tqdm import tqdm
 from pathlib import Path
-from tabulate import tabulate
-from memos.config import settings
-from magika import Magika
 from datetime import datetime
 from enum import Enum
 from typing import List, Tuple
 from functools import lru_cache
 
-import re
-import os
-import time
-import psutil
 
 from collections import defaultdict, deque
 
+# Third-party imports
+import typer
+import requests
+from tqdm import tqdm
+from tabulate import tabulate
+import psutil
 from watchdog.observers import Observer
 from watchdog.events import FileSystemEventHandler
 from concurrent.futures import ThreadPoolExecutor
 
-from memos.models import recreate_fts_and_vec_tables
+# Local imports
+from memos.config import settings
 from memos.utils import get_image_metadata
 from memos.schemas import MetadataSource
 from memos.logging_config import LOGGING_CONFIG
-import logging.config
 
 
 logging.config.dictConfig(LOGGING_CONFIG)
@@ -37,7 +36,7 @@ logger = logging.getLogger(__name__)
 
 lib_app = typer.Typer()
 
-file_detector = Magika()
+file_detector = None
 
 IS_THUMBNAIL = "is_thumbnail"
 
@@ -57,8 +56,20 @@ def format_timestamp(timestamp):
     return datetime.fromtimestamp(timestamp).replace(tzinfo=None).isoformat()
 
 
+def init_file_detector():
+    """Initialize the global file detector if not already initialized"""
+    global file_detector
+    if file_detector is None:
+        from magika import Magika
+
+        file_detector = Magika()
+    return file_detector
+
+
 def get_file_type(file_path):
-    file_result = file_detector.identify_path(file_path)
+    """Get file type using lazy-loaded detector"""
+    detector = init_file_detector()
+    file_result = detector.identify_path(file_path)
     return file_result.output.ct_label, file_result.output.group
 
 
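Whether a particular import actually dominates startup can be checked with CPython's built-in import profiler, e.g. `python -X importtime -c "import magika"`, which prints the cumulative import time of every module pulled in; the module named here is only an example of what one might measure.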
@@ -86,7 +97,7 @@ def display_libraries(libraries):
 
 @lib_app.command("ls")
 def ls():
-    response = httpx.get(f"{BASE_URL}/libraries")
+    response = requests.get(f"{BASE_URL}/libraries")
     libraries = response.json()
     display_libraries(libraries)
 
@@ -105,11 +116,10 @@ def add(name: str, folders: List[str]):
         }
     )
 
-    response = httpx.post(
-        f"{BASE_URL}/libraries",
-        json={"name": name, "folders": absolute_folders},
+    response = requests.post(
+        f"{BASE_URL}/libraries", json={"name": name, "folders": absolute_folders}
     )
-    if 200 <= response.status_code < 300:
+    if response.ok:
         print("Library created successfully")
     else:
         print(f"Failed to create library: {response.status_code} - {response.text}")
@@ -129,7 +139,7 @@ def add_folder(library_id: int, folders: List[str]):
         }
     )
 
-    response = httpx.post(
+    response = requests.post(
         f"{BASE_URL}/libraries/{library_id}/folders",
         json={"folders": absolute_folders},
     )
@@ -143,7 +153,7 @@ def add_folder(library_id: int, folders: List[str]):
 
 @lib_app.command("show")
 def show(library_id: int):
-    response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
+    response = requests.get(f"{BASE_URL}/libraries/{library_id}")
     if response.status_code == 200:
         library = response.json()
         display_libraries([library])
@@ -164,7 +174,8 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
     added_file_count = 0
     scanned_files = set()
     semaphore = asyncio.Semaphore(settings.batchsize)
-    async with httpx.AsyncClient(timeout=60) as client:
+
+    with requests.Session() as session:
         tasks = []
         for root, _, files in os.walk(folder_path):
             with tqdm(total=len(files), desc=f"Scanning {root}", leave=True) as pbar:
@@ -186,12 +197,12 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
                 batch = candidate_files[i : i + batching]
 
                 # Get batch of entities
-                get_response = await client.post(
+                get_response = session.post(
                     f"{BASE_URL}/libraries/{library_id}/entities/by-filepaths",
                     json=batch,
                 )
 
-                if get_response.status_code == 200:
+                if get_response.ok:
                     existing_entities = get_response.json()
                 else:
                     print(
@@ -305,7 +316,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
                     ):
                         tasks.append(
                             update_entity(
-                                client,
+                                session,
                                 semaphore,
                                 plugins,
                                 new_entity,
@@ -315,7 +326,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
                     elif not is_thumbnail:  # Ignore thumbnails
                         tasks.append(
                             add_entity(
-                                client, semaphore, library_id, plugins, new_entity
+                                session, semaphore, library_id, plugins, new_entity
                             )
                         )
                     pbar.update(len(batch))
@@ -372,7 +383,7 @@ def scan(
        print("Error: You cannot specify both a path and folders at the same time.")
        return
 
-    response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
+    response = requests.get(f"{BASE_URL}/libraries/{library_id}")
    if response.status_code != 200:
        print(f"Failed to retrieve library: {response.status_code} - {response.text}")
        return
@@ -428,7 +439,7 @@ def scan(
             total=total_entities, desc="Checking for deleted files", leave=True
         ) as pbar2:
             while True:
-                existing_files_response = httpx.get(
+                existing_files_response = requests.get(
                     f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
                     params={"limit": limit, "offset": offset},
                     timeout=60,
@@ -459,7 +470,7 @@ def scan(
                     and existing_file["filepath"] not in scanned_files
                 ):
                     # File has been deleted
-                    delete_response = httpx.delete(
+                    delete_response = requests.delete(
                         f"{BASE_URL}/libraries/{library_id}/entities/{existing_file['id']}"
                     )
                     if 200 <= delete_response.status_code < 300:
@@ -481,18 +492,18 @@ def scan(
 
 
 async def add_entity(
-    client: httpx.AsyncClient,
+    client: requests.Session,
     semaphore: asyncio.Semaphore,
     library_id,
     plugins,
     new_entity,
-) -> Tuple[FileStatus, bool, httpx.Response]:
+) -> Tuple[FileStatus, bool, requests.Response]:
     async with semaphore:
         MAX_RETRIES = 3
         RETRY_DELAY = 2.0
         for attempt in range(MAX_RETRIES):
             try:
-                post_response = await client.post(
+                post_response = client.post(
                     f"{BASE_URL}/libraries/{library_id}/entities",
                     json=new_entity,
                     params={"plugins": plugins} if plugins else {},
@@ -518,18 +529,18 @@ async def add_entity(
 
 
 async def update_entity(
-    client: httpx.AsyncClient,
+    client: requests.Session,
     semaphore: asyncio.Semaphore,
     plugins,
     new_entity,
     existing_entity,
-) -> Tuple[FileStatus, bool, httpx.Response]:
+) -> Tuple[FileStatus, bool, requests.Response]:
     MAX_RETRIES = 3
     RETRY_DELAY = 2.0
     async with semaphore:
         for attempt in range(MAX_RETRIES):
             try:
-                update_response = await client.put(
+                update_response = client.put(
                     f"{BASE_URL}/entities/{existing_entity['id']}",
                     json=new_entity,
                     params={
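One consequence of the swap worth noting: loop_files, add_entity, and update_entity remain `async` and keep the asyncio.Semaphore, but the formerly awaited httpx calls are now plain blocking requests calls, and a blocking call inside a coroutine occupies the event loop thread, so the gathered tasks effectively run one at a time. A hedged sketch (not from this commit) of how blocking requests calls can still overlap under asyncio, using asyncio.to_thread from Python 3.9+:

    import asyncio
    import requests

    async def fetch(session: requests.Session, url: str) -> requests.Response:
        # requests blocks the calling thread; running it in a worker thread
        # lets other coroutines make progress in the meantime.
        return await asyncio.to_thread(session.get, url, timeout=60)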
@@ -572,8 +583,10 @@ def reindex(
 ):
     print(f"Reindexing library {library_id}")
 
+    from memos.models import recreate_fts_and_vec_tables
+
     # Get the library
-    response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
+    response = requests.get(f"{BASE_URL}/libraries/{library_id}")
     if response.status_code != 200:
         print(f"Failed to get library: {response.status_code} - {response.text}")
         return
@@ -594,7 +607,7 @@ def reindex(
         recreate_fts_and_vec_tables()
         print("FTS and vector tables have been recreated.")
 
-    with httpx.Client(timeout=60) as client:
+    with requests.Session() as client:
         total_entities = 0
 
         # Get total entity count for all folders
@@ -657,7 +670,7 @@ def reindex(
 
 async def check_and_index_entity(client, entity_id, entity_last_scan_at):
     try:
-        index_response = await client.get(f"{BASE_URL}/entities/{entity_id}/index")
+        index_response = client.get(f"{BASE_URL}/entities/{entity_id}/index")
         if index_response.status_code == 200:
             index_data = index_response.json()
             if index_data["last_scan_at"] is None:
@@ -668,14 +681,14 @@ async def check_and_index_entity(client, entity_id, entity_last_scan_at):
             if index_last_scan_at >= entity_last_scan_at:
                 return False  # Index is up to date, no need to update
         return True  # Index doesn't exist or needs update
-    except httpx.HTTPStatusError as e:
+    except requests.HTTPStatusError as e:
         if e.response.status_code == 404:
             return True  # Index doesn't exist, need to create
         raise  # Re-raise other HTTP errors
 
 
 async def index_batch(client, entity_ids):
-    index_response = await client.post(
+    index_response = client.post(
         f"{BASE_URL}/entities/batch-index",
         json=entity_ids,
         timeout=60,
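A caveat on the hunk above: requests does not define `HTTPStatusError` (that name belongs to httpx); the class requests raises from `Response.raise_for_status()` is `requests.HTTPError`. A sketch of the same 404 check written against the exception requests actually exports (function name and URL handling are illustrative, not from this repository):

    import requests

    def needs_index(client: requests.Session, url: str) -> bool:
        try:
            resp = client.get(url, timeout=60)
            resp.raise_for_status()  # raises requests.HTTPError on 4xx/5xx
            return resp.json()["last_scan_at"] is None
        except requests.HTTPError as e:
            if e.response is not None and e.response.status_code == 404:
                return True  # no index yet, so it needs to be created
            raise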
@@ -698,7 +711,7 @@ def sync(
     Sync a specific file with the library.
     """
     # 1. Get library by id and check if it exists
-    response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
+    response = requests.get(f"{BASE_URL}/libraries/{library_id}")
     if response.status_code != 200:
         typer.echo(f"Error: Library with id {library_id} not found.")
         raise typer.Exit(code=1)
@@ -713,7 +726,7 @@ def sync(
         raise typer.Exit(code=1)
 
     # 2. Check if the file exists in the library
-    response = httpx.get(
+    response = requests.get(
         f"{BASE_URL}/libraries/{library_id}/entities/by-filepath",
         params={"filepath": str(file_path)},
     )
@@ -782,7 +795,7 @@ def sync(
         != new_entity["file_last_modified_at"]
         or existing_entity["size"] != new_entity["size"]
     ):
-        update_response = httpx.put(
+        update_response = requests.put(
             f"{BASE_URL}/entities/{existing_entity['id']}",
             json=new_entity,
             params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@@ -812,7 +825,7 @@ def sync(
         # Create new entity
         new_entity["folder_id"] = folder["id"]
 
-        create_response = httpx.post(
+        create_response = requests.post(
             f"{BASE_URL}/libraries/{library_id}/entities",
             json=new_entity,
             params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@@ -1052,7 +1065,7 @@ def watch(
     logger.info(f"Watching library {library_id} for changes...")
 
     # Get the library
-    response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
+    response = requests.get(f"{BASE_URL}/libraries/{library_id}")
     if response.status_code != 200:
         print(f"Error: Library with id {library_id} not found.")
         raise typer.Exit(code=1)
@@ -1,24 +1,24 @@
+# Standard library imports
 import os
 import logging
 from pathlib import Path
 from datetime import datetime, timedelta
 from typing import List
 
-import httpx
+# Third-party imports
+import requests  # replaces httpx
 import typer
 
+# Local imports
 from .config import settings, display_config
-from .models import init_database
-from .record import (
-    run_screen_recorder_once,
-    run_screen_recorder,
-    load_previous_hashes,
-)
-import time
 import sys
 import subprocess
 import platform
-from .cmds.plugin import plugin_app, bind
-from .cmds.library import lib_app, scan, reindex, watch
+from .cmds.plugin import plugin_app
+from .cmds.library import lib_app
 
 import psutil
 import signal
 from tabulate import tabulate
@@ -35,16 +35,16 @@ logging.basicConfig(
 )
 
 # Optionally, you can set the logging level for specific libraries
-logging.getLogger("httpx").setLevel(logging.ERROR)
+logging.getLogger("requests").setLevel(logging.ERROR)
 logging.getLogger("typer").setLevel(logging.ERROR)
 
 
 def check_server_health():
     """Check if the server is running and healthy."""
     try:
-        response = httpx.get(f"{BASE_URL}/health", timeout=5)
+        response = requests.get(f"{BASE_URL}/health", timeout=5)
         return response.status_code == 200
-    except httpx.RequestError:
+    except requests.RequestException:
         return False
 
 
@@ -55,13 +55,11 @@ def callback(ctx: typer.Context):
         "scan",
         "reindex",
         "watch",
-
         "ls",
         "create",
         "add-folder",
         "show",
         "sync",
-
         "bind",
         "unbind",
     ]
@@ -79,9 +77,12 @@ app.add_typer(lib_app, name="lib", callback=callback)
 @app.command()
 def serve():
     """Run the server after initializing if necessary."""
+    from .models import init_database
+
     db_success = init_database()
     if db_success:
         from .server import run_server
 
         run_server()
     else:
         print("Server initialization failed. Unable to start the server.")
@@ -90,6 +91,8 @@ def serve():
 @app.command()
 def init():
     """Initialize the database."""
+    from .models import init_database
+
     db_success = init_database()
     if db_success:
         print("Initialization completed successfully.")
@@ -102,7 +105,8 @@ def get_or_create_default_library():
     Get the default library or create it if it doesn't exist.
     Ensure the library has at least one folder.
     """
-    response = httpx.get(f"{BASE_URL}/libraries")
+    from .cmds.plugin import bind
+    response = requests.get(f"{BASE_URL}/libraries")
     if response.status_code != 200:
         print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
         return None
@@ -114,7 +118,7 @@ def get_or_create_default_library():
 
     if not default_library:
         # Create the default library if it doesn't exist
-        response = httpx.post(
+        response = requests.post(
             f"{BASE_URL}/libraries",
             json={"name": settings.default_library, "folders": []},
         )
@@ -138,7 +142,7 @@ def get_or_create_default_library():
                 screenshots_dir.stat().st_mtime
             ).isoformat(),
         }
-        response = httpx.post(
+        response = requests.post(
            f"{BASE_URL}/libraries/{default_library['id']}/folders",
            json={"folders": [folder]},
        )
@@ -162,13 +166,16 @@ def scan_default_library(
     """
     Scan the screenshots directory and add it to the library if empty.
     """
+    from .cmds.library import scan
+
     default_library = get_or_create_default_library()
     if not default_library:
         return
 
-    # Scan the library
     print(f"Scanning library: {default_library['name']}")
-    scan(default_library["id"], path=path, plugins=plugins, folders=folders, force=force)
+    scan(
+        default_library["id"], path=path, plugins=plugins, folders=folders, force=force
+    )
 
 
 @app.command("reindex")
@@ -180,8 +187,10 @@ def reindex_default_library(
     """
     Reindex the default library for memos.
     """
+    from .cmds.library import reindex
+
     # Get the default library
-    response = httpx.get(f"{BASE_URL}/libraries")
+    response = requests.get(f"{BASE_URL}/libraries")
     if response.status_code != 200:
         print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
         return
@@ -209,6 +218,12 @@ def record(
     """
     Record screenshots of the screen.
     """
+    from .record import (
+        run_screen_recorder_once,
+        run_screen_recorder,
+        load_previous_hashes,
+    )
+
     base_dir = (
         os.path.expanduser(base_dir) if base_dir else settings.resolved_screenshots_dir
     )
@@ -242,6 +257,8 @@ def watch_default_library(
     """
     Watch the default library for file changes and sync automatically.
     """
+    from .cmds.library import watch
+
     default_library = get_or_create_default_library()
     if not default_library:
         return
@@ -343,7 +360,7 @@ def generate_plist():
     python_dir = os.path.dirname(get_python_path())
     log_dir = memos_dir / "logs"
     log_dir.mkdir(parents=True, exist_ok=True)
 
     plist_content = f"""<?xml version="1.0" encoding="UTF-8"?>
 <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
 "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
@@ -1,12 +1,8 @@
 from typing import List
-from sentence_transformers import SentenceTransformer
-import torch
 import numpy as np
-from modelscope import snapshot_download
 from .config import settings
 import logging
 import httpx
-import asyncio
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -18,6 +14,9 @@ device = None
 
 
 def init_embedding_model():
+    import torch
+    from sentence_transformers import SentenceTransformer
+
     global model, device
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -27,6 +26,7 @@ def init_embedding_model():
         device = torch.device("cpu")
 
     if settings.embedding.use_modelscope:
+        from modelscope import snapshot_download
         model_dir = snapshot_download(settings.embedding.model)
         logger.info(f"Model downloaded from ModelScope to: {model_dir}")
     else:
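Function-local imports like the torch, sentence_transformers, and modelscope ones above are only expensive once: Python caches every imported module in sys.modules, so re-executing the import statement on later calls is a dictionary lookup, not a re-import. A small self-contained illustration:

    import sys

    def load_path_class():
        # First call performs the real import; later calls hit the cache.
        from pathlib import Path
        return Path

    load_path_class()
    assert "pathlib" in sys.modules  # cached for the rest of the process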
@@ -475,8 +475,7 @@ def update_fts_and_vec_sync(mapper, connection, entity: EntityModel):
     thread.start()
     thread.join()
 
-
-# Replace the old event listener with the new sync version
+# Add event listeners for EntityModel
 event.listen(EntityModel, "after_insert", update_fts_and_vec_sync)
 event.listen(EntityModel, "after_update", update_fts_and_vec_sync)
 event.listen(EntityModel, "after_delete", delete_fts_and_vec)
@@ -47,6 +47,7 @@ dependencies = [
     "mss",
     "sqlite_vec",
     "watchdog",
+    "requests",
 ]
 
 [project.urls]