feat(typesense): disable typesense by default

This commit is contained in:
arkohut 2024-10-08 18:54:59 +08:00
parent 14ae101f6b
commit b3ebb2c92d
4 changed files with 74 additions and 29 deletions

View File

@ -38,6 +38,16 @@ class EmbeddingSettings(BaseModel):
use_modelscope: bool = False
class TypesenseSettings(BaseModel):
    """Feature toggle and connection settings for the Typesense backend."""

    # Master switch; defaults to off, so Typesense is only used when this is
    # explicitly set to True.
    enabled: bool = False
    # Node address parts handed to typesense.Client ("nodes" entry).
    # Note that the port is declared as a string, not an int.
    host: str = "localhost"
    port: str = "8108"
    protocol: str = "http"
    # Credentials and timeout handed to typesense.Client.
    api_key: str = "xyz"
    connection_timeout_seconds: int = 10
    # Collection name; also exported as the module-level
    # TYPESENSE_COLLECTION_NAME constant.
    collection_name: str = "entities"
class Settings(BaseSettings):
model_config = SettingsConfigDict(
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
@ -50,13 +60,6 @@ class Settings(BaseSettings):
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
typesense_host: str = "localhost"
typesense_port: str = "8108"
typesense_protocol: str = "http"
typesense_api_key: str = "xyz"
typesense_connection_timeout_seconds: int = 10
typesense_collection_name: str = "entities"
# Server settings
server_host: str = "0.0.0.0"
server_port: int = 8080
@ -70,6 +73,9 @@ class Settings(BaseSettings):
# Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings()
# Typesense settings
typesense: TypesenseSettings = TypesenseSettings()
batchsize: int = 1
auth_username: str = "admin"
@ -136,7 +142,7 @@ settings = Settings()
os.makedirs(settings.base_dir, exist_ok=True)
# Global variable for Typesense collection name
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name
# Function to get the database path from environment variable or default

View File

@ -1,5 +1,11 @@
import typesense
from .config import settings, TYPESENSE_COLLECTION_NAME
import sys
# Check if Typesense is enabled
# NOTE(review): this guard sits at module top level, so it runs on import and
# terminates the importing process with exit code 1 whenever Typesense is
# disabled. init_typesense() performs the same check but returns False
# instead of exiting -- confirm that the import-time exit is intended.
if not settings.typesense.enabled:
    print("Error: Typesense is not enabled. Please enable it in the configuration.")
    sys.exit(1)
# Initialize Typesense client
client = typesense.Client(
@ -112,6 +118,10 @@ def update_collection_fields(client, schema):
def init_typesense():
"""Initialize the Typesense collection."""
if not settings.typesense.enabled:
print("Error: Typesense is not enabled. Please enable it in the configuration.")
return False
try:
existing_collections = client.collections.retrieve()
collection_names = [c["name"] for c in existing_collections]
@ -134,6 +144,10 @@ def init_typesense():
if __name__ == "__main__":
import argparse
if not settings.typesense.enabled:
print("Error: Typesense is not enabled. Please enable it in the configuration.")
sys.exit(1)
parser = argparse.ArgumentParser()
parser.add_argument("--force", action="store_true", help="Drop the collection before initializing")
args = parser.parse_args()
@ -145,4 +159,5 @@ if __name__ == "__main__":
except Exception as e:
print(f"Error dropping collection: {e}")
init_typesense()
if not init_typesense():
sys.exit(1)

View File

@ -223,9 +223,9 @@ def init_database():
with engine.connect() as conn:
conn.execute(
DDL(
"""
f"""
CREATE VIRTUAL TABLE IF NOT EXISTS entities_vec USING vec0(
embedding float[768]
embedding float[{settings.embedding.num_dim}]
)
"""
)
@ -430,4 +430,4 @@ def delete_fts_and_vec(mapper, connection, target):
event.listen(EntityModel, "after_insert", update_fts_and_vec)
event.listen(EntityModel, "after_update", update_fts_and_vec)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)
event.listen(EntityModel, "after_delete", delete_fts_and_vec)

View File

@ -17,6 +17,7 @@ import json
import cv2
from PIL import Image
from secrets import compare_digest
import functools
import typesense
@ -58,20 +59,22 @@ engine = create_engine(f"sqlite:///{get_database_path()}")
event.listen(engine, "connect", load_extension)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Initialize Typesense client
client = typesense.Client(
{
"nodes": [
{
"host": settings.typesense_host,
"port": settings.typesense_port,
"protocol": settings.typesense_protocol,
}
],
"api_key": settings.typesense_api_key,
"connection_timeout_seconds": settings.typesense_connection_timeout_seconds,
}
)
# Build the Typesense client only when the feature is switched on; when it is
# off, `client` is bound to None.
client = (
    typesense.Client(
        {
            "nodes": [
                {
                    "host": settings.typesense.host,
                    "port": settings.typesense.port,
                    "protocol": settings.typesense.protocol,
                }
            ],
            "api_key": settings.typesense.api_key,
            "connection_timeout_seconds": settings.typesense.connection_timeout_seconds,
        }
    )
    if settings.typesense.enabled
    else None
)
app.add_middleware(
CORSMiddleware,
@ -370,11 +373,24 @@ async def update_entity(
return entity
def typesense_required(func):
    """Guard an endpoint handler so it only runs while Typesense is enabled.

    Every call to the returned wrapper first checks
    ``settings.typesense.enabled`` and raises a 503 ``HTTPException`` when it
    is false; otherwise the call is forwarded to ``func`` unchanged.

    The wrapper matches the kind of the wrapped callable: a coroutine
    function gets an ``async`` wrapper, a plain function gets a plain
    wrapper. An always-async wrapper would ``await`` the return value of a
    plain ``def`` handler (this decorator is applied to one,
    ``list_entitiy_indices_in_folder``) and fail with ``TypeError``.

    Args:
        func: The handler to guard; either ``def`` or ``async def``.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``).

    Raises:
        HTTPException: Raised by the wrapper, with status 503, when Typesense
            is not enabled.
    """
    # Local import keeps this block self-contained without touching the
    # file's import section.
    import inspect

    def ensure_enabled():
        # Checked on every call rather than once at decoration time.
        if not settings.typesense.enabled:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Typesense is not enabled",
            )

    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            ensure_enabled()
            return await func(*args, **kwargs)

    else:

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ensure_enabled()
            return func(*args, **kwargs)

    return wrapper
@app.post(
"/entities/{entity_id}/index",
status_code=status.HTTP_204_NO_CONTENT,
tags=["entity"],
)
@typesense_required
async def sync_entity_to_typesense(entity_id: int, db: Session = Depends(get_db)):
entity = crud.get_entity_by_id(entity_id, db)
if entity is None:
@ -398,6 +414,7 @@ async def sync_entity_to_typesense(entity_id: int, db: Session = Depends(get_db)
status_code=status.HTTP_204_NO_CONTENT,
tags=["entity"],
)
@typesense_required
async def batch_sync_entities_to_typesense(
entity_ids: List[int], db: Session = Depends(get_db)
):
@ -423,6 +440,7 @@ async def batch_sync_entities_to_typesense(
response_model=EntitySearchResult,
tags=["entity"],
)
@typesense_required
async def get_entity_index(entity_id: int) -> EntityIndexItem:
try:
entity_index_item = indexing.fetch_entity_by_id(client, entity_id)
@ -440,6 +458,7 @@ async def get_entity_index(entity_id: int) -> EntityIndexItem:
status_code=status.HTTP_204_NO_CONTENT,
tags=["entity"],
)
@typesense_required
async def remove_entity_from_typesense(entity_id: int, db: Session = Depends(get_db)):
try:
indexing.remove_entity_by_id(client, entity_id)
@ -456,6 +475,7 @@ async def remove_entity_from_typesense(entity_id: int, db: Session = Depends(get
response_model=List[EntityIndexItem],
tags=["entity"],
)
@typesense_required
def list_entitiy_indices_in_folder(
library_id: int,
folder_id: int,
@ -479,6 +499,7 @@ def list_entitiy_indices_in_folder(
@app.get("/search/v2", response_model=SearchResult, tags=["search"])
@typesense_required
async def search_entities(
q: str,
library_ids: str = Query(None, description="Comma-separated list of library IDs"),
@ -821,9 +842,12 @@ async def search_entities_v2(
def run_server():
print("Database path:", get_database_path())
print(
f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}, Collection Name: {settings.typesense_collection_name}"
)
if settings.typesense.enabled:
print(
f"Typesense connection info: Host: {settings.typesense.host}, Port: {settings.typesense.port}, Protocol: {settings.typesense.protocol}, Collection Name: {settings.typesense.collection_name}"
)
else:
print("Typesense is disabled")
print(f"VLM plugin enabled: {settings.vlm}")
print(f"OCR plugin enabled: {settings.ocr}")