refactor: switch back to httpx

This commit is contained in:
arkohut 2024-11-06 15:29:49 +08:00
parent 9af3c502a9
commit 5dad1675f0
3 changed files with 36 additions and 37 deletions

View File

@ -16,7 +16,7 @@ from collections import defaultdict, deque
# Third-party imports # Third-party imports
import typer import typer
import requests import httpx
from tqdm import tqdm from tqdm import tqdm
from tabulate import tabulate from tabulate import tabulate
import psutil import psutil
@ -97,7 +97,7 @@ def display_libraries(libraries):
@lib_app.command("ls") @lib_app.command("ls")
def ls(): def ls():
response = requests.get(f"{BASE_URL}/libraries") response = httpx.get(f"{BASE_URL}/libraries")
libraries = response.json() libraries = response.json()
display_libraries(libraries) display_libraries(libraries)
@ -116,10 +116,10 @@ def add(name: str, folders: List[str]):
} }
) )
response = requests.post( response = httpx.post(
f"{BASE_URL}/libraries", json={"name": name, "folders": absolute_folders} f"{BASE_URL}/libraries", json={"name": name, "folders": absolute_folders}
) )
if response.ok: if 200 <= response.status_code < 300:
print("Library created successfully") print("Library created successfully")
else: else:
print(f"Failed to create library: {response.status_code} - {response.text}") print(f"Failed to create library: {response.status_code} - {response.text}")
@ -139,7 +139,7 @@ def add_folder(library_id: int, folders: List[str]):
} }
) )
response = requests.post( response = httpx.post(
f"{BASE_URL}/libraries/{library_id}/folders", f"{BASE_URL}/libraries/{library_id}/folders",
json={"folders": absolute_folders}, json={"folders": absolute_folders},
) )
@ -153,7 +153,7 @@ def add_folder(library_id: int, folders: List[str]):
@lib_app.command("show") @lib_app.command("show")
def show(library_id: int): def show(library_id: int):
response = requests.get(f"{BASE_URL}/libraries/{library_id}") response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code == 200: if response.status_code == 200:
library = response.json() library = response.json()
display_libraries([library]) display_libraries([library])
@ -175,7 +175,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
scanned_files = set() scanned_files = set()
semaphore = asyncio.Semaphore(settings.batchsize) semaphore = asyncio.Semaphore(settings.batchsize)
with requests.Session() as session: async with httpx.AsyncClient(timeout=60) as client:
tasks = [] tasks = []
for root, _, files in os.walk(folder_path): for root, _, files in os.walk(folder_path):
with tqdm(total=len(files), desc=f"Scanning {root}", leave=True) as pbar: with tqdm(total=len(files), desc=f"Scanning {root}", leave=True) as pbar:
@ -197,12 +197,12 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
batch = candidate_files[i : i + batching] batch = candidate_files[i : i + batching]
# Get batch of entities # Get batch of entities
get_response = session.post( get_response = await client.post(
f"{BASE_URL}/libraries/{library_id}/entities/by-filepaths", f"{BASE_URL}/libraries/{library_id}/entities/by-filepaths",
json=batch, json=batch,
) )
if get_response.ok: if get_response.status_code == 200:
existing_entities = get_response.json() existing_entities = get_response.json()
else: else:
print( print(
@ -316,7 +316,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
): ):
tasks.append( tasks.append(
update_entity( update_entity(
session, client,
semaphore, semaphore,
plugins, plugins,
new_entity, new_entity,
@ -326,7 +326,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
elif not is_thumbnail: # Ignore thumbnails elif not is_thumbnail: # Ignore thumbnails
tasks.append( tasks.append(
add_entity( add_entity(
session, semaphore, library_id, plugins, new_entity client, semaphore, library_id, plugins, new_entity
) )
) )
pbar.update(len(batch)) pbar.update(len(batch))
@ -383,7 +383,7 @@ def scan(
print("Error: You cannot specify both a path and folders at the same time.") print("Error: You cannot specify both a path and folders at the same time.")
return return
response = requests.get(f"{BASE_URL}/libraries/{library_id}") response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to retrieve library: {response.status_code} - {response.text}") print(f"Failed to retrieve library: {response.status_code} - {response.text}")
return return
@ -439,7 +439,7 @@ def scan(
total=total_entities, desc="Checking for deleted files", leave=True total=total_entities, desc="Checking for deleted files", leave=True
) as pbar2: ) as pbar2:
while True: while True:
existing_files_response = requests.get( existing_files_response = httpx.get(
f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities", f"{BASE_URL}/libraries/{library_id}/folders/{folder['id']}/entities",
params={"limit": limit, "offset": offset}, params={"limit": limit, "offset": offset},
timeout=60, timeout=60,
@ -470,7 +470,7 @@ def scan(
and existing_file["filepath"] not in scanned_files and existing_file["filepath"] not in scanned_files
): ):
# File has been deleted # File has been deleted
delete_response = requests.delete( delete_response = httpx.delete(
f"{BASE_URL}/libraries/{library_id}/entities/{existing_file['id']}" f"{BASE_URL}/libraries/{library_id}/entities/{existing_file['id']}"
) )
if 200 <= delete_response.status_code < 300: if 200 <= delete_response.status_code < 300:
@ -492,18 +492,18 @@ def scan(
async def add_entity( async def add_entity(
client: requests.Session, client: httpx.AsyncClient,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
library_id, library_id,
plugins, plugins,
new_entity, new_entity,
) -> Tuple[FileStatus, bool, requests.Response]: ) -> Tuple[FileStatus, bool, httpx.Response]:
async with semaphore: async with semaphore:
MAX_RETRIES = 3 MAX_RETRIES = 3
RETRY_DELAY = 2.0 RETRY_DELAY = 2.0
for attempt in range(MAX_RETRIES): for attempt in range(MAX_RETRIES):
try: try:
post_response = client.post( post_response = await client.post(
f"{BASE_URL}/libraries/{library_id}/entities", f"{BASE_URL}/libraries/{library_id}/entities",
json=new_entity, json=new_entity,
params={"plugins": plugins} if plugins else {}, params={"plugins": plugins} if plugins else {},
@ -529,18 +529,18 @@ async def add_entity(
async def update_entity( async def update_entity(
client: requests.Session, client: httpx.AsyncClient,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
plugins, plugins,
new_entity, new_entity,
existing_entity, existing_entity,
) -> Tuple[FileStatus, bool, requests.Response]: ) -> Tuple[FileStatus, bool, httpx.Response]:
MAX_RETRIES = 3 MAX_RETRIES = 3
RETRY_DELAY = 2.0 RETRY_DELAY = 2.0
async with semaphore: async with semaphore:
for attempt in range(MAX_RETRIES): for attempt in range(MAX_RETRIES):
try: try:
update_response = client.put( update_response = await client.put(
f"{BASE_URL}/entities/{existing_entity['id']}", f"{BASE_URL}/entities/{existing_entity['id']}",
json=new_entity, json=new_entity,
params={ params={
@ -586,7 +586,7 @@ def reindex(
from memos.models import recreate_fts_and_vec_tables from memos.models import recreate_fts_and_vec_tables
# Get the library # Get the library
response = requests.get(f"{BASE_URL}/libraries/{library_id}") response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to get library: {response.status_code} - {response.text}") print(f"Failed to get library: {response.status_code} - {response.text}")
return return
@ -607,7 +607,7 @@ def reindex(
recreate_fts_and_vec_tables() recreate_fts_and_vec_tables()
print("FTS and vector tables have been recreated.") print("FTS and vector tables have been recreated.")
with requests.Session() as client: with httpx.Client() as client:
total_entities = 0 total_entities = 0
# Get total entity count for all folders # Get total entity count for all folders
@ -681,7 +681,7 @@ async def check_and_index_entity(client, entity_id, entity_last_scan_at):
if index_last_scan_at >= entity_last_scan_at: if index_last_scan_at >= entity_last_scan_at:
return False # Index is up to date, no need to update return False # Index is up to date, no need to update
return True # Index doesn't exist or needs update return True # Index doesn't exist or needs update
except requests.HTTPStatusError as e: except httpx.HTTPStatusError as e:
if e.response.status_code == 404: if e.response.status_code == 404:
return True # Index doesn't exist, need to create return True # Index doesn't exist, need to create
raise # Re-raise other HTTP errors raise # Re-raise other HTTP errors
@ -711,7 +711,7 @@ def sync(
Sync a specific file with the library. Sync a specific file with the library.
""" """
# 1. Get library by id and check if it exists # 1. Get library by id and check if it exists
response = requests.get(f"{BASE_URL}/libraries/{library_id}") response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
typer.echo(f"Error: Library with id {library_id} not found.") typer.echo(f"Error: Library with id {library_id} not found.")
raise typer.Exit(code=1) raise typer.Exit(code=1)
@ -726,7 +726,7 @@ def sync(
raise typer.Exit(code=1) raise typer.Exit(code=1)
# 2. Check if the file exists in the library # 2. Check if the file exists in the library
response = requests.get( response = httpx.get(
f"{BASE_URL}/libraries/{library_id}/entities/by-filepath", f"{BASE_URL}/libraries/{library_id}/entities/by-filepath",
params={"filepath": str(file_path)}, params={"filepath": str(file_path)},
) )
@ -795,7 +795,7 @@ def sync(
!= new_entity["file_last_modified_at"] != new_entity["file_last_modified_at"]
or existing_entity["size"] != new_entity["size"] or existing_entity["size"] != new_entity["size"]
): ):
update_response = requests.put( update_response = httpx.put(
f"{BASE_URL}/entities/{existing_entity['id']}", f"{BASE_URL}/entities/{existing_entity['id']}",
json=new_entity, json=new_entity,
params={"trigger_webhooks_flag": str(not without_webhooks).lower()}, params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@ -825,7 +825,7 @@ def sync(
# Create new entity # Create new entity
new_entity["folder_id"] = folder["id"] new_entity["folder_id"] = folder["id"]
create_response = requests.post( create_response = httpx.post(
f"{BASE_URL}/libraries/{library_id}/entities", f"{BASE_URL}/libraries/{library_id}/entities",
json=new_entity, json=new_entity,
params={"trigger_webhooks_flag": str(not without_webhooks).lower()}, params={"trigger_webhooks_flag": str(not without_webhooks).lower()},
@ -1065,7 +1065,7 @@ def watch(
logger.info(f"Watching library {library_id} for changes...") logger.info(f"Watching library {library_id} for changes...")
# Get the library # Get the library
response = requests.get(f"{BASE_URL}/libraries/{library_id}") response = httpx.get(f"{BASE_URL}/libraries/{library_id}")
if response.status_code != 200: if response.status_code != 200:
print(f"Error: Library with id {library_id} not found.") print(f"Error: Library with id {library_id} not found.")
raise typer.Exit(code=1) raise typer.Exit(code=1)

View File

@ -6,7 +6,7 @@ from datetime import datetime, timedelta
from typing import List from typing import List
# Third-party imports # Third-party imports
import requests # 替换 httpx import httpx
import typer import typer
# Local imports # Local imports
@ -35,16 +35,16 @@ logging.basicConfig(
) )
# Optionally, you can set the logging level for specific libraries # Optionally, you can set the logging level for specific libraries
logging.getLogger("requests").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("typer").setLevel(logging.ERROR) logging.getLogger("typer").setLevel(logging.ERROR)
def check_server_health(): def check_server_health():
"""Check if the server is running and healthy.""" """Check if the server is running and healthy."""
try: try:
response = requests.get(f"{BASE_URL}/health", timeout=5) response = httpx.get(f"{BASE_URL}/health", timeout=5)
return response.status_code == 200 return response.status_code == 200
except requests.RequestException: except httpx.RequestError:
return False return False
@ -106,7 +106,7 @@ def get_or_create_default_library():
Ensure the library has at least one folder. Ensure the library has at least one folder.
""" """
from .cmds.plugin import bind from .cmds.plugin import bind
response = requests.get(f"{BASE_URL}/libraries") response = httpx.get(f"{BASE_URL}/libraries")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return None return None
@ -118,7 +118,7 @@ def get_or_create_default_library():
if not default_library: if not default_library:
# Create the default library if it doesn't exist # Create the default library if it doesn't exist
response = requests.post( response = httpx.post(
f"{BASE_URL}/libraries", f"{BASE_URL}/libraries",
json={"name": settings.default_library, "folders": []}, json={"name": settings.default_library, "folders": []},
) )
@ -142,7 +142,7 @@ def get_or_create_default_library():
screenshots_dir.stat().st_mtime screenshots_dir.stat().st_mtime
).isoformat(), ).isoformat(),
} }
response = requests.post( response = httpx.post(
f"{BASE_URL}/libraries/{default_library['id']}/folders", f"{BASE_URL}/libraries/{default_library['id']}/folders",
json={"folders": [folder]}, json={"folders": [folder]},
) )
@ -190,7 +190,7 @@ def reindex_default_library(
from .cmds.library import reindex from .cmds.library import reindex
# Get the default library # Get the default library
response = requests.get(f"{BASE_URL}/libraries") response = httpx.get(f"{BASE_URL}/libraries")
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return return

View File

@ -47,7 +47,6 @@ dependencies = [
"mss", "mss",
"sqlite_vec", "sqlite_vec",
"watchdog", "watchdog",
"requests",
] ]
[project.urls] [project.urls]