mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat: enable baisc auth
This commit is contained in:
parent
f1f77bb906
commit
dbbc2792ef
@ -7,7 +7,7 @@ from pydantic_settings import (
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SecretStr
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
|
||||
@ -18,7 +18,7 @@ class VLMSettings(BaseModel):
|
||||
endpoint: str = "http://localhost:11434"
|
||||
token: str = ""
|
||||
concurrency: int = 4
|
||||
force_jpeg: bool = False # Add this line
|
||||
force_jpeg: bool = False
|
||||
|
||||
|
||||
class OCRSettings(BaseModel):
|
||||
@ -28,6 +28,7 @@ class OCRSettings(BaseModel):
|
||||
concurrency: int = 4
|
||||
use_local: bool = True
|
||||
use_gpu: bool = False
|
||||
force_jpeg: bool = False
|
||||
|
||||
|
||||
class EmbeddingSettings(BaseModel):
|
||||
@ -65,9 +66,11 @@ class Settings(BaseSettings):
|
||||
# Embedding settings
|
||||
embedding: EmbeddingSettings = EmbeddingSettings()
|
||||
|
||||
# New batchsize setting
|
||||
batchsize: int = 4
|
||||
|
||||
auth_username: str = "admin"
|
||||
auth_password: SecretStr = SecretStr("changeme")
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
@ -77,7 +80,10 @@ class Settings(BaseSettings):
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (env_settings, YamlConfigSettingsSource(settings_cls),)
|
||||
return (
|
||||
env_settings,
|
||||
YamlConfigSettingsSource(settings_cls),
|
||||
)
|
||||
|
||||
|
||||
def dict_representer(dumper, data):
|
||||
|
@ -6,16 +6,17 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import List, Annotated
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import logging # Import logging module
|
||||
import logging
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from .read_metadata import read_metadata
|
||||
from secrets import compare_digest
|
||||
|
||||
import typesense
|
||||
|
||||
@ -43,14 +44,12 @@ from .schemas import (
|
||||
EntitySearchResult,
|
||||
SearchResult,
|
||||
)
|
||||
|
||||
# Import the logging configuration
|
||||
from .read_metadata import read_metadata
|
||||
from .logging_config import LOGGING_CONFIG
|
||||
|
||||
# Configure logging to include datetime
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
)
|
||||
# Initialize FastAPI app and other global variables
|
||||
app = FastAPI()
|
||||
security = HTTPBasic()
|
||||
|
||||
engine = create_engine(f"sqlite:///{get_database_path()}")
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
@ -70,9 +69,6 @@ client = typesense.Client(
|
||||
}
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Adjust this as needed
|
||||
@ -111,8 +107,34 @@ async def favicon_ico():
|
||||
return FileResponse(os.path.join(current_dir, "static/favicon.png"))
|
||||
|
||||
|
||||
def is_auth_enabled():
|
||||
return bool(settings.auth_username and settings.auth_password.get_secret_value())
|
||||
|
||||
|
||||
def authenticate(credentials: HTTPBasicCredentials = Depends(security)):
|
||||
if not is_auth_enabled():
|
||||
return None
|
||||
correct_username = compare_digest(credentials.username, settings.auth_username)
|
||||
correct_password = compare_digest(
|
||||
credentials.password, settings.auth_password.get_secret_value()
|
||||
)
|
||||
if not (correct_username and correct_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
return credentials.username
|
||||
|
||||
|
||||
def optional_auth(credentials: HTTPBasicCredentials = Depends(security)):
|
||||
if is_auth_enabled():
|
||||
return authenticate(credentials)
|
||||
return None
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def serve_spa():
|
||||
async def serve_spa(username: str = Depends(optional_auth)):
|
||||
return FileResponse(os.path.join(current_dir, "static/app.html"))
|
||||
|
||||
|
||||
@ -602,7 +624,7 @@ def add_library_plugin(
|
||||
@app.delete(
|
||||
"/libraries/{library_id}/plugins/{plugin_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
tags=["plugin"]
|
||||
tags=["plugin"],
|
||||
)
|
||||
def delete_library_plugin(
|
||||
library_id: int, plugin_id: int, db: Session = Depends(get_db)
|
||||
@ -610,15 +632,13 @@ def delete_library_plugin(
|
||||
library = crud.get_library_by_id(library_id, db)
|
||||
if library is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Library not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
|
||||
)
|
||||
|
||||
plugin = crud.get_plugin_by_id(plugin_id, db)
|
||||
if plugin is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Plugin not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found"
|
||||
)
|
||||
|
||||
crud.remove_plugin_from_library(library_id, plugin_id, db)
|
||||
|
Loading…
x
Reference in New Issue
Block a user