From d3b45ad1971fd6e0c12e9a0ffb3626c984cceec0 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:01:49 +0800 Subject: [PATCH] feat: enable config.yaml --- memos/config.py | 51 ++++++++++++++++++++++++++++++++++++--- memos/plugins/ocr/main.py | 19 ++++++--------- memos/plugins/vlm/main.py | 48 ++++++++++++++++++++++++------------ 3 files changed, 87 insertions(+), 31 deletions(-) diff --git a/memos/config.py b/memos/config.py index eb7047f..4257d33 100644 --- a/memos/config.py +++ b/memos/config.py @@ -1,7 +1,15 @@ import os from pathlib import Path -from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import Tuple, Type +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + YamlConfigSettingsSource, +) from pydantic import BaseModel +import yaml +from collections import OrderedDict class VLMSettings(BaseModel): @@ -42,7 +50,7 @@ class Settings(BaseSettings): typesense_collection_name: str = "entities" # Server settings - server_host: str = "0.0.0.0" # Add this line + server_host: str = "0.0.0.0" server_port: int = 8080 # VLM plugin settings @@ -54,6 +62,42 @@ class Settings(BaseSettings): # Embedding settings embedding: EmbeddingSettings = EmbeddingSettings() + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return (env_settings, YamlConfigSettingsSource(settings_cls),) + + +def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + +yaml.add_representer(OrderedDict, dict_representer) + + +def create_default_config(): + config_path = Path.home() / ".memos" / "config.yaml" + if not config_path.exists(): + settings = Settings() + os.makedirs(config_path.parent, exist_ok=True) + with open(config_path, "w") as f: + # Convert settings to a dictionary and ensure order + settings_dict = settings.model_dump() + ordered_settings = OrderedDict( + (key, settings_dict[key]) for key in settings.model_fields.keys() + ) + yaml.dump(ordered_settings, f, Dumper=yaml.Dumper) + + +# Create default config if it doesn't exist +create_default_config() + settings = Settings() @@ -63,6 +107,7 @@ os.makedirs(settings.base_dir, exist_ok=True) # Global variable for Typesense collection name TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name + # Function to get the database path from environment variable or default def get_database_path(): - return settings.database_path \ No newline at end of file + return settings.database_path diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py index 073e8bb..674f4e8 100644 --- a/memos/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -4,8 +4,6 @@ from typing import Optional import httpx import json import base64 -import io -import os from PIL import Image from fastapi import APIRouter, FastAPI, Request, HTTPException @@ -14,10 +12,7 @@ from memos.schemas import Entity, MetadataType METADATA_FIELD_NAME = "ocr_result" PLUGIN_NAME = "ocr" -router = APIRouter( - tags=[PLUGIN_NAME], - responses={404: {"description": "Not found"}} -) +router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}}) endpoint = None token = None concurrency = None @@ -88,9 +83,8 @@ async def ocr(entity: Entity, request: Request): ocr_result = await predict(entity.filepath) - print(ocr_result) if ocr_result is None or not ocr_result: - print(f"No OCR result found for file: {entity.filepath}") + logger.info(f"No OCR result found for file: {entity.filepath}") return {METADATA_FIELD_NAME: "{}"} # Call the URL to patch the entity's metadata @@ -133,9 +127,10 @@ def init_plugin(config): concurrency = config.concurrency semaphore = asyncio.Semaphore(concurrency) - print(f"Endpoint: {endpoint}") - print(f"Token: {token}") - print(f"Concurrency: {concurrency}") + logger.info("OCR plugin initialized") + logger.info(f"Endpoint: {endpoint}") + logger.info(f"Token: {token}") + logger.info(f"Concurrency: {concurrency}") if __name__ == "__main__": @@ -167,4 +162,4 @@ if __name__ == "__main__": app = FastAPI() app.include_router(router) - uvicorn.run(app, host="0.0.0.0", port=args.port) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index 30286ae..39bfd4d 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -13,10 +13,7 @@ import os PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" -router = APIRouter( - tags=[PLUGIN_NAME], - responses={404: {"description": "Not found"}} -) +router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}}) modelname = None endpoint = None @@ -56,7 +53,11 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N response.raise_for_status() result = response.json() choices = result.get("choices", []) - if choices and "message" in choices[0] and "content" in choices[0]["message"]: + if ( + choices + and "message" in choices[0] + and "content" in choices[0]["message"] + ): return choices[0]["message"]["content"] return "" except Exception as e: @@ -70,19 +71,21 @@ async def predict( img_base64 = image2base64(img_path) if not img_base64: return None - + # Get the file extension _, file_extension = os.path.splitext(img_path) - file_extension = file_extension.lower()[1:] # Remove the dot and convert to lowercase + file_extension = file_extension.lower()[ + 1: + ] # Remove the dot and convert to lowercase # Determine the MIME type mime_types = { - 'png': 'image/png', - 'jpg': 'image/jpeg', - 'jpeg': 'image/jpeg', - 'webp': 'image/webp' + "png": "image/png", + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "webp": "image/webp", } - mime_type = mime_types.get(file_extension, 'image/jpeg') + mime_type = mime_types.get(file_extension, "image/jpeg") request_data = { "model": modelname, @@ -192,17 +195,30 @@ def init_plugin(config): concurrency = config.concurrency semaphore = asyncio.Semaphore(concurrency) + # Print the parameters + logger.info("VLM plugin initialized") + logger.info(f"Model Name: {modelname}") + logger.info(f"Endpoint: {endpoint}") + logger.info(f"Token: {token}") + logger.info(f"Concurrency: {concurrency}") + if __name__ == "__main__": import argparse from fastapi import FastAPI parser = argparse.ArgumentParser(description="VLM Plugin Configuration") - parser.add_argument("--model-name", type=str, default="your_model_name", help="Model name") - parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL") + parser.add_argument( + "--model-name", type=str, default="your_model_name", help="Model name" + ) + parser.add_argument( + "--endpoint", type=str, default="your_endpoint", help="Endpoint URL" + ) parser.add_argument("--token", type=str, default="your_token", help="Access token") parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level") - parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") + parser.add_argument( + "--port", type=int, default=8000, help="Port to run the server on" + ) args = parser.parse_args() @@ -217,4 +233,4 @@ if __name__ == "__main__": app = FastAPI() app.include_router(router) - uvicorn.run(app, host="0.0.0.0", port=args.port) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=args.port)