From a60488e80cc4474bfbe10c176eaae090f5332f33 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Sat, 19 Oct 2024 14:41:15 +0800 Subject: [PATCH] feat(config): use default config file --- .gitignore | 1 + memos/commands.py | 4 +-- memos/config.py | 55 +++++++++++++++++++---------------- memos/default_config.yaml | 60 +++++++++++++++++++++++++++++++++++++++ memos/record.py | 2 +- 5 files changed, 94 insertions(+), 28 deletions(-) create mode 100644 memos/default_config.yaml diff --git a/.gitignore b/.gitignore index ccf27c9..b08b64e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ typesense-data/ test-data/ memos/static/ db/ +memos/plugins/ocr/temp_ppocr.yaml diff --git a/memos/commands.py b/memos/commands.py index 1ca24e9..23b8735 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -29,7 +29,7 @@ app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]}) app.add_typer(plugin_app, name="plugin") app.add_typer(lib_app, name="lib") -BASE_URL = f"http://{settings.server_host}:{settings.server_port}" +BASE_URL = settings.server_endpoint # Configure logging logging.basicConfig( @@ -207,7 +207,7 @@ def record( """ Record screenshots of the screen. """ - base_dir = os.path.expanduser(base_dir) if base_dir else settings.screenshots_dir + base_dir = os.path.expanduser(base_dir) if base_dir else settings.resolved_screenshots_dir previous_hashes = load_previous_hashes(base_dir) if once: diff --git a/memos/config.py b/memos/config.py index cf604c8..6940a8e 100644 --- a/memos/config.py +++ b/memos/config.py @@ -1,4 +1,5 @@ import os +import shutil from pathlib import Path from typing import Tuple, Type, List from pydantic_settings import ( @@ -65,10 +66,10 @@ class Settings(BaseSettings): env_prefix="MEMOS_", ) - base_dir: str = str(Path.home() / ".memos") - database_path: str = os.path.join(base_dir, "database.db") + base_dir: str = "~/.memos" + database_path: str = "database.db" default_library: str = "screenshots" - screenshots_dir: str = os.path.join(base_dir, "screenshots") + screenshots_dir: str = "screenshots" # Server settings server_host: str = "127.0.0.1" @@ -91,7 +92,7 @@ class Settings(BaseSettings): auth_username: str = "admin" auth_password: SecretStr = SecretStr("changeme") - default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"] + default_plugins: List[str] = ["builtin_ocr"] @classmethod def settings_customise_sources( @@ -107,6 +108,23 @@ class Settings(BaseSettings): YamlConfigSettingsSource(settings_cls), ) + @property + def resolved_base_dir(self) -> Path: + return Path(self.base_dir).expanduser().resolve() + + @property + def resolved_database_path(self) -> Path: + return self.resolved_base_dir / self.database_path + + @property + def resolved_screenshots_dir(self) -> Path: + return self.resolved_base_dir / self.screenshots_dir + + @property + def server_endpoint(self) -> str: + host = "127.0.0.1" if self.server_host == "0.0.0.0" else self.server_host + return f"http://{host}:{self.server_port}" + def dict_representer(dumper, data): return dumper.represent_dict(data.items()) @@ -133,35 +151,19 @@ yaml.add_representer(SecretStr, secret_str_representer) def create_default_config(): config_path = Path.home() / ".memos" / "config.yaml" if not config_path.exists(): - settings = Settings() + template_path = Path(__file__).parent / "default_config.yaml" os.makedirs(config_path.parent, exist_ok=True) - - # 将设置转换为字典并确保顺序 - settings_dict = settings.model_dump() - ordered_settings = OrderedDict( - (key, settings_dict[key]) for key in settings.model_fields.keys() - ) - - # 使用 io.StringIO 作为中间步骤 - with io.StringIO() as string_buffer: - yaml.dump( - ordered_settings, string_buffer, allow_unicode=True, Dumper=yaml.Dumper - ) - yaml_content = string_buffer.getvalue() - - # 将内容写入文件,确保使用 UTF-8 编码 - with open(config_path, "w", encoding="utf-8") as f: - f.write(yaml_content) + shutil.copy(template_path, config_path) + print(f"Created default configuration at {config_path}") # Create default config if it doesn't exist create_default_config() - settings = Settings() # Define the default database path -os.makedirs(settings.base_dir, exist_ok=True) +os.makedirs(settings.resolved_base_dir, exist_ok=True) # Global variable for Typesense collection name TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name @@ -169,7 +171,7 @@ TYPESENSE_COLLECTION_NAME = settings.typesense.collection_name # Function to get the database path from environment variable or default def get_database_path(): - return settings.database_path + return str(settings.resolved_database_path) def format_value(value): @@ -197,6 +199,9 @@ def display_config(): typer.echo("Current configuration settings:") for key, value in config_dict.items(): formatted_value = format_value(value) + if key in ["base_dir", "database_path", "screenshots_dir"]: + resolved_value = getattr(settings, f"resolved_{key}") + formatted_value += f" (resolved: {resolved_value})" if "\n" in formatted_value: typer.echo(f"{key}:") for line in formatted_value.split("\n"): diff --git a/memos/default_config.yaml b/memos/default_config.yaml new file mode 100644 index 0000000..34e55c0 --- /dev/null +++ b/memos/default_config.yaml @@ -0,0 +1,60 @@ +base_dir: ~/.memos +database_path: database.db +default_library: screenshots +screenshots_dir: screenshots + +server_host: 127.0.0.1 +server_port: 8839 + +# using ollama as the vlm server +vlm: + concurrency: 1 + enabled: true + endpoint: http://localhost:11434 + force_jpeg: true + modelname: minicpm-v + prompt: 请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等 + token: '' + +# using local ocr +ocr: + concurrency: 1 + enabled: true + endpoint: http://localhost:5555/predict # this is not used + force_jpeg: false + token: '' + use_local: true + +# using local embedding +embedding: + enabled: true + endpoint: http://localhost:11434/api/embed # this is not used + model: arkohut/jina-embeddings-v2-base-zh + num_dim: 768 + use_local: true + use_modelscope: false + +# using ollama embedding +# embedding: +# enabled: true +# endpoint: http://localhost:11434/api/embed # this is not used +# model: arkohut/gte-qwen2-1.5b-instruct:q8_0 +# num_dim: 1536 +# use_local: false +# use_modelscope: false + +typesense: + api_key: xyz + collection_name: entities + connection_timeout_seconds: 10 + enabled: false + host: localhost + port: '8108' + protocol: http + +batchsize: 1 +auth_username: admin +auth_password: changeme +default_plugins: +- builtin_ocr +# - builtin_vlm diff --git a/memos/record.py b/memos/record.py index ea93d1b..e3f2c2c 100644 --- a/memos/record.py +++ b/memos/record.py @@ -337,7 +337,7 @@ def main(): args = parser.parse_args() base_dir = ( - os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir + os.path.expanduser(args.base_dir) if args.base_dir else settings.resolved_screenshots_dir ) previous_hashes = load_previous_hashes(base_dir)