mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-09 20:47:11 +00:00
feat: enable config.yaml
This commit is contained in:
parent
1a08a44a4d
commit
d3b45ad197
@ -1,7 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from typing import Tuple, Type
|
||||||
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
PydanticBaseSettingsSource,
|
||||||
|
SettingsConfigDict,
|
||||||
|
YamlConfigSettingsSource,
|
||||||
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
import yaml
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class VLMSettings(BaseModel):
|
class VLMSettings(BaseModel):
|
||||||
@ -42,7 +50,7 @@ class Settings(BaseSettings):
|
|||||||
typesense_collection_name: str = "entities"
|
typesense_collection_name: str = "entities"
|
||||||
|
|
||||||
# Server settings
|
# Server settings
|
||||||
server_host: str = "0.0.0.0" # Add this line
|
server_host: str = "0.0.0.0"
|
||||||
server_port: int = 8080
|
server_port: int = 8080
|
||||||
|
|
||||||
# VLM plugin settings
|
# VLM plugin settings
|
||||||
@ -54,6 +62,42 @@ class Settings(BaseSettings):
|
|||||||
# Embedding settings
|
# Embedding settings
|
||||||
embedding: EmbeddingSettings = EmbeddingSettings()
|
embedding: EmbeddingSettings = EmbeddingSettings()
|
||||||
|
|
||||||
|
@classmethod
def settings_customise_sources(
    cls,
    settings_cls: Type[BaseSettings],
    init_settings: PydanticBaseSettingsSource,
    env_settings: PydanticBaseSettingsSource,
    dotenv_settings: PydanticBaseSettingsSource,
    file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
    """Select the sources that settings values are loaded from.

    The returned tuple is ordered from highest to lowest priority:

    1. ``init_settings`` - keyword arguments passed to the constructor.
       Leaving this source out makes explicit ``Settings(field=...)``
       overrides silently ineffective, so it is listed first.
    2. ``env_settings`` - environment variables.
    3. A ``YamlConfigSettingsSource`` built for ``settings_cls``.

    ``dotenv_settings`` and ``file_secret_settings`` are accepted because
    the hook signature requires them, but they are not consulted.

    Args:
        settings_cls: The settings class being instantiated.
        init_settings: Source backed by constructor keyword arguments.
        env_settings: Source backed by environment variables.
        dotenv_settings: Source backed by a dotenv file (unused).
        file_secret_settings: Source backed by secret files (unused).

    Returns:
        The enabled sources, highest priority first.
    """
    yaml_settings = YamlConfigSettingsSource(settings_cls)
    return (init_settings, env_settings, yaml_settings)
|
||||||
|
|
||||||
|
|
||||||
|
def dict_representer(dumper, data):
    """Represent a mapping through ``dumper`` using its key/value pairs.

    Args:
        dumper: Object exposing ``represent_dict``.
        data: Mapping whose ``items()`` are handed to the dumper.

    Returns:
        Whatever ``dumper.represent_dict`` returns for those items.
    """
    pairs = data.items()
    return dumper.represent_dict(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
yaml.add_representer(OrderedDict, dict_representer)
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_config():
    """Create ``~/.memos/config.yaml`` from a fresh ``Settings()`` if absent.

    Does nothing when the file already exists. Otherwise the values of a
    newly constructed ``Settings`` instance are dumped as YAML, with
    top-level keys in field declaration order, and written to the file.
    The parent directory is created when missing.
    """
    config_path = Path.home() / ".memos" / "config.yaml"
    if config_path.exists():
        return

    defaults = Settings()
    dumped = defaults.model_dump()
    # Read the field list from the class: accessing ``model_fields`` on an
    # instance is deprecated in recent Pydantic releases.
    ordered_settings = OrderedDict(
        (key, dumped[key]) for key in type(defaults).model_fields
    )

    # Serialise BEFORE opening the file. Opening with "w" creates/truncates
    # it immediately, so a failure during dumping would otherwise leave an
    # empty config.yaml behind that the existence check above would then
    # never regenerate.
    # ``sort_keys=False`` keeps nested sections in declaration order too;
    # the default would sort them alphabetically.
    # ``allow_unicode`` is deliberately left off so the output stays pure
    # ASCII and is readable regardless of the encoding used to load it.
    yaml_text = yaml.dump(ordered_settings, Dumper=yaml.Dumper, sort_keys=False)

    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(yaml_text)
|
||||||
|
|
||||||
|
|
||||||
|
# Create default config if it doesn't exist
|
||||||
|
create_default_config()
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
@ -63,6 +107,7 @@ os.makedirs(settings.base_dir, exist_ok=True)
|
|||||||
# Global variable for Typesense collection name
|
# Global variable for Typesense collection name
|
||||||
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
||||||
|
|
||||||
|
|
||||||
# Function to get the database path from environment variable or default
|
# Function to get the database path from environment variable or default
|
||||||
def get_database_path():
|
def get_database_path():
|
||||||
return settings.database_path
|
return settings.database_path
|
||||||
|
@ -4,8 +4,6 @@ from typing import Optional
|
|||||||
import httpx
|
import httpx
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
import io
|
|
||||||
import os
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
||||||
@ -14,10 +12,7 @@ from memos.schemas import Entity, MetadataType
|
|||||||
METADATA_FIELD_NAME = "ocr_result"
|
METADATA_FIELD_NAME = "ocr_result"
|
||||||
PLUGIN_NAME = "ocr"
|
PLUGIN_NAME = "ocr"
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
|
||||||
tags=[PLUGIN_NAME],
|
|
||||||
responses={404: {"description": "Not found"}}
|
|
||||||
)
|
|
||||||
endpoint = None
|
endpoint = None
|
||||||
token = None
|
token = None
|
||||||
concurrency = None
|
concurrency = None
|
||||||
@ -88,9 +83,8 @@ async def ocr(entity: Entity, request: Request):
|
|||||||
|
|
||||||
ocr_result = await predict(entity.filepath)
|
ocr_result = await predict(entity.filepath)
|
||||||
|
|
||||||
print(ocr_result)
|
|
||||||
if ocr_result is None or not ocr_result:
|
if ocr_result is None or not ocr_result:
|
||||||
print(f"No OCR result found for file: {entity.filepath}")
|
logger.info(f"No OCR result found for file: {entity.filepath}")
|
||||||
return {METADATA_FIELD_NAME: "{}"}
|
return {METADATA_FIELD_NAME: "{}"}
|
||||||
|
|
||||||
# Call the URL to patch the entity's metadata
|
# Call the URL to patch the entity's metadata
|
||||||
@ -133,9 +127,10 @@ def init_plugin(config):
|
|||||||
concurrency = config.concurrency
|
concurrency = config.concurrency
|
||||||
semaphore = asyncio.Semaphore(concurrency)
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
print(f"Endpoint: {endpoint}")
|
logger.info("OCR plugin initialized")
|
||||||
print(f"Token: {token}")
|
logger.info(f"Endpoint: {endpoint}")
|
||||||
print(f"Concurrency: {concurrency}")
|
logger.info(f"Token: {token}")
|
||||||
|
logger.info(f"Concurrency: {concurrency}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -167,4 +162,4 @@ if __name__ == "__main__":
|
|||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||||
|
@ -13,10 +13,7 @@ import os
|
|||||||
PLUGIN_NAME = "vlm"
|
PLUGIN_NAME = "vlm"
|
||||||
PROMPT = "描述这张图片的内容"
|
PROMPT = "描述这张图片的内容"
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
|
||||||
tags=[PLUGIN_NAME],
|
|
||||||
responses={404: {"description": "Not found"}}
|
|
||||||
)
|
|
||||||
|
|
||||||
modelname = None
|
modelname = None
|
||||||
endpoint = None
|
endpoint = None
|
||||||
@ -56,7 +53,11 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
choices = result.get("choices", [])
|
choices = result.get("choices", [])
|
||||||
if choices and "message" in choices[0] and "content" in choices[0]["message"]:
|
if (
|
||||||
|
choices
|
||||||
|
and "message" in choices[0]
|
||||||
|
and "content" in choices[0]["message"]
|
||||||
|
):
|
||||||
return choices[0]["message"]["content"]
|
return choices[0]["message"]["content"]
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -70,19 +71,21 @@ async def predict(
|
|||||||
img_base64 = image2base64(img_path)
|
img_base64 = image2base64(img_path)
|
||||||
if not img_base64:
|
if not img_base64:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get the file extension
|
# Get the file extension
|
||||||
_, file_extension = os.path.splitext(img_path)
|
_, file_extension = os.path.splitext(img_path)
|
||||||
file_extension = file_extension.lower()[1:] # Remove the dot and convert to lowercase
|
file_extension = file_extension.lower()[
|
||||||
|
1:
|
||||||
|
] # Remove the dot and convert to lowercase
|
||||||
|
|
||||||
# Determine the MIME type
|
# Determine the MIME type
|
||||||
mime_types = {
|
mime_types = {
|
||||||
'png': 'image/png',
|
"png": "image/png",
|
||||||
'jpg': 'image/jpeg',
|
"jpg": "image/jpeg",
|
||||||
'jpeg': 'image/jpeg',
|
"jpeg": "image/jpeg",
|
||||||
'webp': 'image/webp'
|
"webp": "image/webp",
|
||||||
}
|
}
|
||||||
mime_type = mime_types.get(file_extension, 'image/jpeg')
|
mime_type = mime_types.get(file_extension, "image/jpeg")
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"model": modelname,
|
"model": modelname,
|
||||||
@ -192,17 +195,30 @@ def init_plugin(config):
|
|||||||
concurrency = config.concurrency
|
concurrency = config.concurrency
|
||||||
semaphore = asyncio.Semaphore(concurrency)
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
|
# Log the parameters
|
||||||
|
logger.info("VLM plugin initialized")
|
||||||
|
logger.info(f"Model Name: {modelname}")
|
||||||
|
logger.info(f"Endpoint: {endpoint}")
|
||||||
|
logger.info(f"Token: {token}")
|
||||||
|
logger.info(f"Concurrency: {concurrency}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
|
parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
|
||||||
parser.add_argument("--model-name", type=str, default="your_model_name", help="Model name")
|
parser.add_argument(
|
||||||
parser.add_argument("--endpoint", type=str, default="your_endpoint", help="Endpoint URL")
|
"--model-name", type=str, default="your_model_name", help="Model name"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
|
||||||
|
)
|
||||||
parser.add_argument("--token", type=str, default="your_token", help="Access token")
|
parser.add_argument("--token", type=str, default="your_token", help="Access token")
|
||||||
parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
|
parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
|
||||||
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=8000, help="Port to run the server on"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -217,4 +233,4 @@ if __name__ == "__main__":
|
|||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user