mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-07 03:35:24 +00:00
feat: enable config.yaml
This commit is contained in:
parent
1a08a44a4d
commit
d3b45ad197
@ -1,7 +1,15 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Tuple, Type
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class VLMSettings(BaseModel):
|
||||
@ -42,7 +50,7 @@ class Settings(BaseSettings):
|
||||
typesense_collection_name: str = "entities"
|
||||
|
||||
# Server settings
|
||||
server_host: str = "0.0.0.0" # Add this line
|
||||
server_host: str = "0.0.0.0"
|
||||
server_port: int = 8080
|
||||
|
||||
# VLM plugin settings
|
||||
@ -54,6 +62,42 @@ class Settings(BaseSettings):
|
||||
# Embedding settings
|
||||
embedding: EmbeddingSettings = EmbeddingSettings()
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (env_settings, YamlConfigSettingsSource(settings_cls),)
|
||||
|
||||
|
||||
def dict_representer(dumper, data):
|
||||
return dumper.represent_dict(data.items())
|
||||
|
||||
|
||||
yaml.add_representer(OrderedDict, dict_representer)
|
||||
|
||||
|
||||
def create_default_config():
|
||||
config_path = Path.home() / ".memos" / "config.yaml"
|
||||
if not config_path.exists():
|
||||
settings = Settings()
|
||||
os.makedirs(config_path.parent, exist_ok=True)
|
||||
with open(config_path, "w") as f:
|
||||
# Convert settings to a dictionary and ensure order
|
||||
settings_dict = settings.model_dump()
|
||||
ordered_settings = OrderedDict(
|
||||
(key, settings_dict[key]) for key in settings.model_fields.keys()
|
||||
)
|
||||
yaml.dump(ordered_settings, f, Dumper=yaml.Dumper)
|
||||
|
||||
|
||||
# Create default config if it doesn't exist
|
||||
create_default_config()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@ -63,6 +107,7 @@ os.makedirs(settings.base_dir, exist_ok=True)
|
||||
# Global variable for Typesense collection name
|
||||
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
||||
|
||||
|
||||
# Function to get the database path from environment variable or default
|
||||
def get_database_path():
|
||||
return settings.database_path
|
||||
return settings.database_path
|
||||
|
@ -4,8 +4,6 @@ from typing import Optional
|
||||
import httpx
|
||||
import json
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
||||
@ -14,10 +12,7 @@ from memos.schemas import Entity, MetadataType
|
||||
METADATA_FIELD_NAME = "ocr_result"
|
||||
PLUGIN_NAME = "ocr"
|
||||
|
||||
router = APIRouter(
|
||||
tags=[PLUGIN_NAME],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
|
||||
endpoint = None
|
||||
token = None
|
||||
concurrency = None
|
||||
@ -88,9 +83,8 @@ async def ocr(entity: Entity, request: Request):
|
||||
|
||||
ocr_result = await predict(entity.filepath)
|
||||
|
||||
print(ocr_result)
|
||||
if ocr_result is None or not ocr_result:
|
||||
print(f"No OCR result found for file: {entity.filepath}")
|
||||
logger.info(f"No OCR result found for file: {entity.filepath}")
|
||||
return {METADATA_FIELD_NAME: "{}"}
|
||||
|
||||
# Call the URL to patch the entity's metadata
|
||||
@ -133,9 +127,10 @@ def init_plugin(config):
|
||||
concurrency = config.concurrency
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
print(f"Endpoint: {endpoint}")
|
||||
print(f"Token: {token}")
|
||||
print(f"Concurrency: {concurrency}")
|
||||
logger.info("OCR plugin initialized")
|
||||
logger.info(f"Endpoint: {endpoint}")
|
||||
logger.info(f"Token: {token}")
|
||||
logger.info(f"Concurrency: {concurrency}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -167,4 +162,4 @@ if __name__ == "__main__":
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
|
@ -13,10 +13,7 @@ import os
|
||||
PLUGIN_NAME = "vlm"
|
||||
PROMPT = "描述这张图片的内容"
|
||||
|
||||
router = APIRouter(
|
||||
tags=[PLUGIN_NAME],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
|
||||
|
||||
modelname = None
|
||||
endpoint = None
|
||||
@ -56,7 +53,11 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
choices = result.get("choices", [])
|
||||
if choices and "message" in choices[0] and "content" in choices[0]["message"]:
|
||||
if (
|
||||
choices
|
||||
and "message" in choices[0]
|
||||
and "content" in choices[0]["message"]
|
||||
):
|
||||
return choices[0]["message"]["content"]
|
||||
return ""
|
||||
except Exception as e:
|
||||
@ -70,19 +71,21 @@ async def predict(
|
||||
img_base64 = image2base64(img_path)
|
||||
if not img_base64:
|
||||
return None
|
||||
|
||||
|
||||
# Get the file extension
|
||||
_, file_extension = os.path.splitext(img_path)
|
||||
file_extension = file_extension.lower()[1:] # Remove the dot and convert to lowercase
|
||||
file_extension = file_extension.lower()[
|
||||
1:
|
||||
] # Remove the dot and convert to lowercase
|
||||
|
||||
# Determine the MIME type
|
||||
mime_types = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'webp': 'image/webp'
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
mime_type = mime_types.get(file_extension, 'image/jpeg')
|
||||
mime_type = mime_types.get(file_extension, "image/jpeg")
|
||||
|
||||
request_data = {
|
||||
"model": modelname,
|
||||
@ -192,17 +195,30 @@ def init_plugin(config):
|
||||
concurrency = config.concurrency
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
# Print the parameters
|
||||
logger.info("VLM plugin initialized")
|
||||
logger.info(f"Model Name: {modelname}")
|
||||
logger.info(f"Endpoint: {endpoint}")
|
||||
logger.info(f"Token: {token}")
|
||||
logger.info(f"Concurrency: {concurrency}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from fastapi import FastAPI
|
||||
|
||||
parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
|
||||
parser.add_argument("--model-name", type=str, default="your_model_name", help="Model name")
|
||||
parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL")
|
||||
parser.add_argument(
|
||||
"--model-name", type=str, default="your_model_name", help="Model name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
|
||||
)
|
||||
parser.add_argument("--token", type=str, default="your_token", help="Access token")
|
||||
parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="Port to run the server on"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -217,4 +233,4 @@ if __name__ == "__main__":
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
|
Loading…
x
Reference in New Issue
Block a user