feat(config): use default config file

This commit is contained in:
arkohut 2024-10-19 14:41:15 +08:00
parent 189b82739d
commit a60488e80c
5 changed files with 94 additions and 28 deletions

1
.gitignore vendored
View File

@ -9,3 +9,4 @@ typesense-data/
test-data/
memos/static/
db/
memos/plugins/ocr/temp_ppocr.yaml

View File

@ -29,7 +29,7 @@ app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]})
app.add_typer(plugin_app, name="plugin")
app.add_typer(lib_app, name="lib")
BASE_URL = f"http://{settings.server_host}:{settings.server_port}"
BASE_URL = settings.server_endpoint
# Configure logging
logging.basicConfig(
@ -207,7 +207,7 @@ def record(
"""
Record screenshots of the screen.
"""
base_dir = os.path.expanduser(base_dir) if base_dir else settings.screenshots_dir
base_dir = os.path.expanduser(base_dir) if base_dir else settings.resolved_screenshots_dir
previous_hashes = load_previous_hashes(base_dir)
if once:

View File

@ -1,4 +1,5 @@
import os
import shutil
from pathlib import Path
from typing import Tuple, Type, List
from pydantic_settings import (
@ -65,10 +66,10 @@ class Settings(BaseSettings):
env_prefix="MEMOS_",
)
base_dir: str = str(Path.home() / ".memos")
database_path: str = os.path.join(base_dir, "database.db")
base_dir: str = "~/.memos"
database_path: str = "database.db"
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
screenshots_dir: str = "screenshots"
# Server settings
server_host: str = "127.0.0.1"
@ -91,7 +92,7 @@ class Settings(BaseSettings):
auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme")
default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"]
default_plugins: List[str] = ["builtin_ocr"]
@classmethod
def settings_customise_sources(
@ -107,6 +108,23 @@ class Settings(BaseSettings):
YamlConfigSettingsSource(settings_cls),
)
@property
def resolved_base_dir(self) -> Path:
return Path(self.base_dir).expanduser().resolve()
@property
def resolved_database_path(self) -> Path:
return self.resolved_base_dir / self.database_path
@property
def resolved_screenshots_dir(self) -> Path:
return self.resolved_base_dir / self.screenshots_dir
@property
def server_endpoint(self) -> str:
host = "127.0.0.1" if self.server_host == "0.0.0.0" else self.server_host
return f"http://{host}:{self.server_port}"
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
@ -133,35 +151,19 @@ yaml.add_representer(SecretStr, secret_str_representer)
def create_default_config():
config_path = Path.home() / ".memos" / "config.yaml"
if not config_path.exists():
settings = Settings()
template_path = Path(__file__).parent / "default_config.yaml"
os.makedirs(config_path.parent, exist_ok=True)
# 将设置转换为字典并确保顺序
settings_dict = settings.model_dump()
ordered_settings = OrderedDict(
(key, settings_dict[key]) for key in settings.model_fields.keys()
)
# 使用 io.StringIO 作为中间步骤
with io.StringIO() as string_buffer:
yaml.dump(
ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper
)
yaml_content = string_buffer.getvalue()
# 将内容写入文件,确保使用 UTF-8 编码
with open(config_path, "w", encoding="utf-8") as f:
f.write(yaml_content)
shutil.copy(template_path, config_path)
print(f"Created default configuration at {config_path}")
# Create default config if it doesn't exist
create_default_config()
settings = Settings()
# Define the default database path
os.makedirs(settings.base_dir, exist_ok=True)
os.makedirs(settings.resolved_base_dir, exist_ok=True)
# Global variable for Typesense collection name
TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name
@ -169,7 +171,7 @@ TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name
# Function to get the database path from environment variable or default
def get_database_path():
return settings.database_path
return str(settings.resolved_database_path)
def format_value(value):
@ -197,6 +199,9 @@ def display_config():
typer.echo("Current configuration settings:")
for key, value in config_dict.items():
formatted_value = format_value(value)
if key in ["base_dir", "database_path", "screenshots_dir"]:
resolved_value = getattr(settings, f"resolved_{key}")
formatted_value += f" (resolved: {resolved_value})"
if "\n" in formatted_value:
typer.echo(f"{key}:")
for line in formatted_value.split("\n"):

60
memos/default_config.yaml Normal file
View File

@ -0,0 +1,60 @@
base_dir: ~/.memos
database_path: database.db
default_library: screenshots
screenshots_dir: screenshots
server_host: 127.0.0.1
server_port: 8839
# using ollama as the vlm server
vlm:
concurrency: 1
enabled: true
endpoint: http://localhost:11434
force_jpeg: true
modelname: minicpm-v
prompt: 请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等
token: ''
# using local ocr
ocr:
concurrency: 1
enabled: true
endpoint: http://localhost:5555/predict # this is not used
force_jpeg: false
token: ''
use_local: true
# using local embedding
embedding:
enabled: true
endpoint: http://localhost:11434/api/embed # this is not used
model: arkohut/jina-embeddings-v2-base-zh
num_dim: 768
use_local: true
use_modelscope: false
# using ollama embedding
# embedding:
# enabled: true
# endpoint: http://localhost:11434/api/embed # this is not used
# model: arkohut/gte-qwen2-1.5b-instruct:q8_0
# num_dim: 1536
# use_local: false
# use_modelscope: false
typesense:
api_key: xyz
collection_name: entities
connection_timeout_seconds: 10
enabled: false
host: localhost
port: '8108'
protocol: http
batchsize: 1
auth_username: admin
auth_password: changeme
default_plugins:
- builtin_ocr
# - builtin_vlm

View File

@ -337,7 +337,7 @@ def main():
args = parser.parse_args()
base_dir = (
os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir
os.path.expanduser(args.base_dir) if args.base_dir else settings.resolved_screenshots_dir
)
previous_hashes = load_previous_hashes(base_dir)