feat: enable baisc auth

This commit is contained in:
arkohut 2024-09-06 18:18:18 +08:00
parent f1f77bb906
commit dbbc2792ef
2 changed files with 47 additions and 21 deletions

View File

@ -7,7 +7,7 @@ from pydantic_settings import (
SettingsConfigDict, SettingsConfigDict,
YamlConfigSettingsSource, YamlConfigSettingsSource,
) )
from pydantic import BaseModel from pydantic import BaseModel, SecretStr
import yaml import yaml
from collections import OrderedDict from collections import OrderedDict
@ -18,7 +18,7 @@ class VLMSettings(BaseModel):
endpoint: str = "http://localhost:11434" endpoint: str = "http://localhost:11434"
token: str = "" token: str = ""
concurrency: int = 4 concurrency: int = 4
force_jpeg: bool = False # Add this line force_jpeg: bool = False
class OCRSettings(BaseModel): class OCRSettings(BaseModel):
@ -28,6 +28,7 @@ class OCRSettings(BaseModel):
concurrency: int = 4 concurrency: int = 4
use_local: bool = True use_local: bool = True
use_gpu: bool = False use_gpu: bool = False
force_jpeg: bool = False
class EmbeddingSettings(BaseModel): class EmbeddingSettings(BaseModel):
@ -65,9 +66,11 @@ class Settings(BaseSettings):
# Embedding settings # Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings() embedding: EmbeddingSettings = EmbeddingSettings()
# New batchsize setting
batchsize: int = 4 batchsize: int = 4
auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme")
@classmethod @classmethod
def settings_customise_sources( def settings_customise_sources(
cls, cls,
@ -77,7 +80,10 @@ class Settings(BaseSettings):
dotenv_settings: PydanticBaseSettingsSource, dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource, file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]: ) -> Tuple[PydanticBaseSettingsSource, ...]:
return (env_settings, YamlConfigSettingsSource(settings_cls),) return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
def dict_representer(dumper, data): def dict_representer(dumper, data):

View File

@ -6,16 +6,17 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from typing import List, Annotated from typing import List, Annotated
from pathlib import Path from pathlib import Path
import asyncio import asyncio
import logging # Import logging module import logging
import cv2 import cv2
from PIL import Image from PIL import Image
from .read_metadata import read_metadata from secrets import compare_digest
import typesense import typesense
@ -43,14 +44,12 @@ from .schemas import (
EntitySearchResult, EntitySearchResult,
SearchResult, SearchResult,
) )
from .read_metadata import read_metadata
# Import the logging configuration
from .logging_config import LOGGING_CONFIG from .logging_config import LOGGING_CONFIG
# Configure logging to include datetime # Initialize FastAPI app and other global variables
logging.basicConfig( app = FastAPI()
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO security = HTTPBasic()
)
engine = create_engine(f"sqlite:///{get_database_path()}") engine = create_engine(f"sqlite:///{get_database_path()}")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@ -70,9 +69,6 @@ client = typesense.Client(
} }
) )
app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], # Adjust this as needed allow_origins=["*"], # Adjust this as needed
@ -111,8 +107,34 @@ async def favicon_ico():
return FileResponse(os.path.join(current_dir, "static/favicon.png")) return FileResponse(os.path.join(current_dir, "static/favicon.png"))
def is_auth_enabled():
return bool(settings.auth_username and settings.auth_password.get_secret_value())
def authenticate(credentials: HTTPBasicCredentials = Depends(security)):
if not is_auth_enabled():
return None
correct_username = compare_digest(credentials.username, settings.auth_username)
correct_password = compare_digest(
credentials.password, settings.auth_password.get_secret_value()
)
if not (correct_username and correct_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return credentials.username
def optional_auth(credentials: HTTPBasicCredentials = Depends(security)):
if is_auth_enabled():
return authenticate(credentials)
return None
@app.get("/") @app.get("/")
async def serve_spa(): async def serve_spa(username: str = Depends(optional_auth)):
return FileResponse(os.path.join(current_dir, "static/app.html")) return FileResponse(os.path.join(current_dir, "static/app.html"))
@ -602,7 +624,7 @@ def add_library_plugin(
@app.delete( @app.delete(
"/libraries/{library_id}/plugins/{plugin_id}", "/libraries/{library_id}/plugins/{plugin_id}",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
tags=["plugin"] tags=["plugin"],
) )
def delete_library_plugin( def delete_library_plugin(
library_id: int, plugin_id: int, db: Session = Depends(get_db) library_id: int, plugin_id: int, db: Session = Depends(get_db)
@ -610,15 +632,13 @@ def delete_library_plugin(
library = crud.get_library_by_id(library_id, db) library = crud.get_library_by_id(library_id, db)
if library is None: if library is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
detail="Library not found"
) )
plugin = crud.get_plugin_by_id(plugin_id, db) plugin = crud.get_plugin_by_id(plugin_id, db)
if plugin is None: if plugin is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found"
detail="Plugin not found"
) )
crud.remove_plugin_from_library(library_id, plugin_id, db) crud.remove_plugin_from_library(library_id, plugin_id, db)