feat: enable config.yaml

This commit is contained in:
arkohut 2024-09-01 19:01:49 +08:00
parent 1a08a44a4d
commit d3b45ad197
3 changed files with 87 additions and 31 deletions

View File

@ -1,7 +1,15 @@
import os
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Tuple, Type
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
from pydantic import BaseModel
import yaml
from collections import OrderedDict
class VLMSettings(BaseModel):
@ -42,7 +50,7 @@ class Settings(BaseSettings):
typesense_collection_name: str = "entities"
# Server settings
server_host: str = "0.0.0.0" # Add this line
server_host: str = "0.0.0.0"
server_port: int = 8080
# VLM plugin settings
@ -54,6 +62,42 @@ class Settings(BaseSettings):
# Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings()
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (env_settings, YamlConfigSettingsSource(settings_cls),)
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
yaml.add_representer(OrderedDict, dict_representer)
def create_default_config():
config_path = Path.home() / ".memos" / "config.yaml"
if not config_path.exists():
settings = Settings()
os.makedirs(config_path.parent, exist_ok=True)
with open(config_path, "w") as f:
# Convert settings to a dictionary and ensure order
settings_dict = settings.model_dump()
ordered_settings = OrderedDict(
(key, settings_dict[key]) for key in settings.model_fields.keys()
)
yaml.dump(ordered_settings, f, Dumper=yaml.Dumper)
# Create default config if it doesn't exist
create_default_config()
settings = Settings()
@ -63,6 +107,7 @@ os.makedirs(settings.base_dir, exist_ok=True)
# Global variable for Typesense collection name
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
# Function to get the database path from environment variable or default
def get_database_path():
return settings.database_path
return settings.database_path

View File

@ -4,8 +4,6 @@ from typing import Optional
import httpx
import json
import base64
import io
import os
from PIL import Image
from fastapi import APIRouter, FastAPI, Request, HTTPException
@ -14,10 +12,7 @@ from memos.schemas import Entity, MetadataType
METADATA_FIELD_NAME = "ocr_result"
PLUGIN_NAME = "ocr"
router = APIRouter(
tags=[PLUGIN_NAME],
responses={404: {"description": "Not found"}}
)
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
endpoint = None
token = None
concurrency = None
@ -88,9 +83,8 @@ async def ocr(entity: Entity, request: Request):
ocr_result = await predict(entity.filepath)
print(ocr_result)
if ocr_result is None or not ocr_result:
print(f"No OCR result found for file: {entity.filepath}")
logger.info(f"No OCR result found for file: {entity.filepath}")
return {METADATA_FIELD_NAME: "{}"}
# Call the URL to patch the entity's metadata
@ -133,9 +127,10 @@ def init_plugin(config):
concurrency = config.concurrency
semaphore = asyncio.Semaphore(concurrency)
print(f"Endpoint: {endpoint}")
print(f"Token: {token}")
print(f"Concurrency: {concurrency}")
logger.info("OCR plugin initialized")
logger.info(f"Endpoint: {endpoint}")
logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
if __name__ == "__main__":
@ -167,4 +162,4 @@ if __name__ == "__main__":
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -13,10 +13,7 @@ import os
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
router = APIRouter(
tags=[PLUGIN_NAME],
responses={404: {"description": "Not found"}}
)
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
modelname = None
endpoint = None
@ -56,7 +53,11 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
response.raise_for_status()
result = response.json()
choices = result.get("choices", [])
if choices and "message" in choices[0] and "content" in choices[0]["message"]:
if (
choices
and "message" in choices[0]
and "content" in choices[0]["message"]
):
return choices[0]["message"]["content"]
return ""
except Exception as e:
@ -70,19 +71,21 @@ async def predict(
img_base64 = image2base64(img_path)
if not img_base64:
return None
# Get the file extension
_, file_extension = os.path.splitext(img_path)
file_extension = file_extension.lower()[1:] # Remove the dot and convert to lowercase
file_extension = file_extension.lower()[
1:
] # Remove the dot and convert to lowercase
# Determine the MIME type
mime_types = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'webp': 'image/webp'
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"webp": "image/webp",
}
mime_type = mime_types.get(file_extension, 'image/jpeg')
mime_type = mime_types.get(file_extension, "image/jpeg")
request_data = {
"model": modelname,
@ -192,17 +195,30 @@ def init_plugin(config):
concurrency = config.concurrency
semaphore = asyncio.Semaphore(concurrency)
# Print the parameters
logger.info("VLM plugin initialized")
logger.info(f"Model Name: {modelname}")
logger.info(f"Endpoint: {endpoint}")
logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
if __name__ == "__main__":
import argparse
from fastapi import FastAPI
parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
parser.add_argument("--model-name", type=str, default="your_model_name", help="Model name")
parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL")
parser.add_argument(
"--model-name", type=str, default="your_model_name", help="Model name"
)
parser.add_argument(
"--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
)
parser.add_argument("--token", type=str, default="your_token", help="Access token")
parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args()
@ -217,4 +233,4 @@ if __name__ == "__main__":
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port)
uvicorn.run(app, host="0.0.0.0", port=args.port)