feat: enable config.yaml

This commit is contained in:
arkohut 2024-09-01 19:01:49 +08:00
parent 1a08a44a4d
commit d3b45ad197
3 changed files with 87 additions and 31 deletions

View File

@ -1,7 +1,15 @@
import os import os
from pathlib import Path from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Tuple, Type
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
from pydantic import BaseModel from pydantic import BaseModel
import yaml
from collections import OrderedDict
class VLMSettings(BaseModel): class VLMSettings(BaseModel):
@ -42,7 +50,7 @@ class Settings(BaseSettings):
typesense_collection_name: str = "entities" typesense_collection_name: str = "entities"
# Server settings # Server settings
server_host: str = "0.0.0.0" # Add this line server_host: str = "0.0.0.0"
server_port: int = 8080 server_port: int = 8080
# VLM plugin settings # VLM plugin settings
@ -54,6 +62,42 @@ class Settings(BaseSettings):
# Embedding settings # Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings() embedding: EmbeddingSettings = EmbeddingSettings()
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (env_settings, YamlConfigSettingsSource(settings_cls),)
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
yaml.add_representer(OrderedDict, dict_representer)
def create_default_config():
config_path = Path.home() / ".memos" / "config.yaml"
if not config_path.exists():
settings = Settings()
os.makedirs(config_path.parent, exist_ok=True)
with open(config_path, "w") as f:
# Convert settings to a dictionary and ensure order
settings_dict = settings.model_dump()
ordered_settings = OrderedDict(
(key, settings_dict[key]) for key in settings.model_fields.keys()
)
yaml.dump(ordered_settings, f, Dumper=yaml.Dumper)
# Create default config if it doesn't exist
create_default_config()
settings = Settings() settings = Settings()
@ -63,6 +107,7 @@ os.makedirs(settings.base_dir, exist_ok=True)
# Global variable for Typesense collection name # Global variable for Typesense collection name
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
# Function to get the database path from environment variable or default # Function to get the database path from environment variable or default
def get_database_path(): def get_database_path():
return settings.database_path return settings.database_path

View File

@ -4,8 +4,6 @@ from typing import Optional
import httpx import httpx
import json import json
import base64 import base64
import io
import os
from PIL import Image from PIL import Image
from fastapi import APIRouter, FastAPI, Request, HTTPException from fastapi import APIRouter, FastAPI, Request, HTTPException
@ -14,10 +12,7 @@ from memos.schemas import Entity, MetadataType
METADATA_FIELD_NAME = "ocr_result" METADATA_FIELD_NAME = "ocr_result"
PLUGIN_NAME = "ocr" PLUGIN_NAME = "ocr"
router = APIRouter( router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
tags=[PLUGIN_NAME],
responses={404: {"description": "Not found"}}
)
endpoint = None endpoint = None
token = None token = None
concurrency = None concurrency = None
@ -88,9 +83,8 @@ async def ocr(entity: Entity, request: Request):
ocr_result = await predict(entity.filepath) ocr_result = await predict(entity.filepath)
print(ocr_result)
if ocr_result is None or not ocr_result: if ocr_result is None or not ocr_result:
print(f"No OCR result found for file: {entity.filepath}") logger.info(f"No OCR result found for file: {entity.filepath}")
return {METADATA_FIELD_NAME: "{}"} return {METADATA_FIELD_NAME: "{}"}
# Call the URL to patch the entity's metadata # Call the URL to patch the entity's metadata
@ -133,9 +127,10 @@ def init_plugin(config):
concurrency = config.concurrency concurrency = config.concurrency
semaphore = asyncio.Semaphore(concurrency) semaphore = asyncio.Semaphore(concurrency)
print(f"Endpoint: {endpoint}") logger.info("OCR plugin initialized")
print(f"Token: {token}") logger.info(f"Endpoint: {endpoint}")
print(f"Concurrency: {concurrency}") logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
if __name__ == "__main__": if __name__ == "__main__":
@ -167,4 +162,4 @@ if __name__ == "__main__":
app = FastAPI() app = FastAPI()
app.include_router(router) app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port) uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -13,10 +13,7 @@ import os
PLUGIN_NAME = "vlm" PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容" PROMPT = "描述这张图片的内容"
router = APIRouter( router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
tags=[PLUGIN_NAME],
responses={404: {"description": "Not found"}}
)
modelname = None modelname = None
endpoint = None endpoint = None
@ -56,7 +53,11 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
choices = result.get("choices", []) choices = result.get("choices", [])
if choices and "message" in choices[0] and "content" in choices[0]["message"]: if (
choices
and "message" in choices[0]
and "content" in choices[0]["message"]
):
return choices[0]["message"]["content"] return choices[0]["message"]["content"]
return "" return ""
except Exception as e: except Exception as e:
@ -70,19 +71,21 @@ async def predict(
img_base64 = image2base64(img_path) img_base64 = image2base64(img_path)
if not img_base64: if not img_base64:
return None return None
# Get the file extension # Get the file extension
_, file_extension = os.path.splitext(img_path) _, file_extension = os.path.splitext(img_path)
file_extension = file_extension.lower()[1:] # Remove the dot and convert to lowercase file_extension = file_extension.lower()[
1:
] # Remove the dot and convert to lowercase
# Determine the MIME type # Determine the MIME type
mime_types = { mime_types = {
'png': 'image/png', "png": "image/png",
'jpg': 'image/jpeg', "jpg": "image/jpeg",
'jpeg': 'image/jpeg', "jpeg": "image/jpeg",
'webp': 'image/webp' "webp": "image/webp",
} }
mime_type = mime_types.get(file_extension, 'image/jpeg') mime_type = mime_types.get(file_extension, "image/jpeg")
request_data = { request_data = {
"model": modelname, "model": modelname,
@ -192,17 +195,30 @@ def init_plugin(config):
concurrency = config.concurrency concurrency = config.concurrency
semaphore = asyncio.Semaphore(concurrency) semaphore = asyncio.Semaphore(concurrency)
# Print the parameters
logger.info("VLM plugin initialized")
logger.info(f"Model Name: {modelname}")
logger.info(f"Endpoint: {endpoint}")
logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
from fastapi import FastAPI from fastapi import FastAPI
parser = argparse.ArgumentParser(description="VLM Plugin Configuration") parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
parser.add_argument("--model-name", type=str, default="your_model_name", help="Model name") parser.add_argument(
parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL") "--model-name", type=str, default="your_model_name", help="Model name"
)
parser.add_argument(
"--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
)
parser.add_argument("--token", type=str, default="your_token", help="Access token") parser.add_argument("--token", type=str, default="your_token", help="Access token")
parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level") parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args() args = parser.parse_args()
@ -217,4 +233,4 @@ if __name__ == "__main__":
app = FastAPI() app = FastAPI()
app.include_router(router) app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port) uvicorn.run(app, host="0.0.0.0", port=args.port)