diff --git a/memos/config.py b/memos/config.py index d2beff7..79799c1 100644 --- a/memos/config.py +++ b/memos/config.py @@ -7,7 +7,7 @@ from pydantic_settings import ( SettingsConfigDict, YamlConfigSettingsSource, ) -from pydantic import BaseModel +from pydantic import BaseModel, SecretStr import yaml from collections import OrderedDict @@ -18,7 +18,7 @@ class VLMSettings(BaseModel): endpoint: str = "http://localhost:11434" token: str = "" concurrency: int = 4 - force_jpeg: bool = False # Add this line + force_jpeg: bool = False class OCRSettings(BaseModel): @@ -28,6 +28,7 @@ class OCRSettings(BaseModel): concurrency: int = 4 use_local: bool = True use_gpu: bool = False + force_jpeg: bool = False class EmbeddingSettings(BaseModel): @@ -65,9 +66,11 @@ class Settings(BaseSettings): # Embedding settings embedding: EmbeddingSettings = EmbeddingSettings() - # New batchsize setting batchsize: int = 4 + auth_username: str = "admin" + auth_password: SecretStr = SecretStr("changeme") + @classmethod def settings_customise_sources( cls, @@ -77,7 +80,10 @@ class Settings(BaseSettings): dotenv_settings: PydanticBaseSettingsSource, file_secret_settings: PydanticBaseSettingsSource, ) -> Tuple[PydanticBaseSettingsSource, ...]: - return (env_settings, YamlConfigSettingsSource(settings_cls),) + return ( + env_settings, + YamlConfigSettingsSource(settings_cls), + ) def dict_representer(dumper, data): diff --git a/memos/server.py b/memos/server.py index d1b0617..9702b68 100644 --- a/memos/server.py +++ b/memos/server.py @@ -6,16 +6,17 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse, JSONResponse from fastapi.encoders import jsonable_encoder +from fastapi.security import HTTPBasic, HTTPBasicCredentials from sqlalchemy.orm import Session from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from typing import List, Annotated from pathlib import Path import asyncio -import logging # Import logging module +import logging import cv2 from PIL import Image -from .read_metadata import read_metadata +from secrets import compare_digest import typesense @@ -43,14 +44,12 @@ from .schemas import ( EntitySearchResult, SearchResult, ) - -# Import the logging configuration +from .read_metadata import read_metadata from .logging_config import LOGGING_CONFIG -# Configure logging to include datetime -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO -) +# Initialize FastAPI app and other global variables +app = FastAPI() +security = HTTPBasic() engine = create_engine(f"sqlite:///{get_database_path()}") SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -70,9 +69,6 @@ client = typesense.Client( } ) -app = FastAPI() - - app.add_middleware( CORSMiddleware, allow_origins=["*"], # Adjust this as needed @@ -111,8 +107,34 @@ async def favicon_ico(): return FileResponse(os.path.join(current_dir, "static/favicon.png")) +def is_auth_enabled(): + return bool(settings.auth_username and settings.auth_password.get_secret_value()) + + +def authenticate(credentials: HTTPBasicCredentials = Depends(security)): + if not is_auth_enabled(): + return None + correct_username = compare_digest(credentials.username, settings.auth_username) + correct_password = compare_digest( + credentials.password, settings.auth_password.get_secret_value() + ) + if not (correct_username and correct_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) + return credentials.username + + +def optional_auth(credentials: HTTPBasicCredentials = Depends(security)): + if is_auth_enabled(): + return authenticate(credentials) + return None + + @app.get("/") -async def serve_spa(): +async def serve_spa(username: str = Depends(optional_auth)): return FileResponse(os.path.join(current_dir, "static/app.html")) @@ -602,7 +624,7 @@ def add_library_plugin( @app.delete( "/libraries/{library_id}/plugins/{plugin_id}", status_code=status.HTTP_204_NO_CONTENT, - tags=["plugin"] + tags=["plugin"], ) def delete_library_plugin( library_id: int, plugin_id: int, db: Session = Depends(get_db) @@ -610,15 +632,13 @@ def delete_library_plugin( library = crud.get_library_by_id(library_id, db) if library is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Library not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Library not found" ) plugin = crud.get_plugin_by_id(plugin_id, db) if plugin is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Plugin not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found" ) crud.remove_plugin_from_library(library_id, plugin_id, db)