feat: enable baisc auth

This commit is contained in:
arkohut 2024-09-06 18:18:18 +08:00
parent f1f77bb906
commit dbbc2792ef
2 changed files with 47 additions and 21 deletions

View File

@ -7,7 +7,7 @@ from pydantic_settings import (
SettingsConfigDict,
YamlConfigSettingsSource,
)
from pydantic import BaseModel
from pydantic import BaseModel, SecretStr
import yaml
from collections import OrderedDict
@ -18,7 +18,7 @@ class VLMSettings(BaseModel):
endpoint: str = "http://localhost:11434"
token: str = ""
concurrency: int = 4
force_jpeg: bool = False # Add this line
force_jpeg: bool = False
class OCRSettings(BaseModel):
@ -28,6 +28,7 @@ class OCRSettings(BaseModel):
concurrency: int = 4
use_local: bool = True
use_gpu: bool = False
force_jpeg: bool = False
class EmbeddingSettings(BaseModel):
@ -65,9 +66,11 @@ class Settings(BaseSettings):
# Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings()
# New batchsize setting
batchsize: int = 4
auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme")
@classmethod
def settings_customise_sources(
cls,
@ -77,7 +80,10 @@ class Settings(BaseSettings):
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (env_settings, YamlConfigSettingsSource(settings_cls),)
return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
def dict_representer(dumper, data):

View File

@ -6,16 +6,17 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from fastapi.encoders import jsonable_encoder
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from sqlalchemy.orm import Session
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from typing import List, Annotated
from pathlib import Path
import asyncio
import logging # Import logging module
import logging
import cv2
from PIL import Image
from .read_metadata import read_metadata
from secrets import compare_digest
import typesense
@ -43,14 +44,12 @@ from .schemas import (
EntitySearchResult,
SearchResult,
)
# Import the logging configuration
from .read_metadata import read_metadata
from .logging_config import LOGGING_CONFIG
# Configure logging to include datetime
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
# Initialize FastAPI app and other global variables
app = FastAPI()
security = HTTPBasic()
engine = create_engine(f"sqlite:///{get_database_path()}")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@ -70,9 +69,6 @@ client = typesense.Client(
}
)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Adjust this as needed
@ -111,8 +107,34 @@ async def favicon_ico():
return FileResponse(os.path.join(current_dir, "static/favicon.png"))
def is_auth_enabled():
return bool(settings.auth_username and settings.auth_password.get_secret_value())
def authenticate(credentials: HTTPBasicCredentials = Depends(security)):
if not is_auth_enabled():
return None
correct_username = compare_digest(credentials.username, settings.auth_username)
correct_password = compare_digest(
credentials.password, settings.auth_password.get_secret_value()
)
if not (correct_username and correct_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return credentials.username
def optional_auth(credentials: HTTPBasicCredentials = Depends(security)):
if is_auth_enabled():
return authenticate(credentials)
return None
@app.get("/")
async def serve_spa():
async def serve_spa(username: str = Depends(optional_auth)):
return FileResponse(os.path.join(current_dir, "static/app.html"))
@ -602,7 +624,7 @@ def add_library_plugin(
@app.delete(
"/libraries/{library_id}/plugins/{plugin_id}",
status_code=status.HTTP_204_NO_CONTENT,
tags=["plugin"]
tags=["plugin"],
)
def delete_library_plugin(
library_id: int, plugin_id: int, db: Session = Depends(get_db)
@ -610,15 +632,13 @@ def delete_library_plugin(
library = crud.get_library_by_id(library_id, db)
if library is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Library not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
)
plugin = crud.get_plugin_by_id(plugin_id, db)
if plugin is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Plugin not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found"
)
crud.remove_plugin_from_library(library_id, plugin_id, db)