feat(plugins): add ocr as build plugin

This commit is contained in:
arkohut 2024-08-25 16:48:28 +08:00
parent ec7ba1f989
commit 67a5e10d3e
14 changed files with 134 additions and 33 deletions

View File

@ -7,13 +7,18 @@ class VLMSettings(BaseModel):
enabled: bool = False enabled: bool = False
modelname: str = "internvl-1.5" modelname: str = "internvl-1.5"
endpoint: str = "http://localhost:11434" endpoint: str = "http://localhost:11434"
class OCRSettings(BaseModel):
enabled: bool = True
endpoint: str = "http://localhost:5555/predict"
token: str = "" token: str = ""
concurrency: int = 8 concurrency: int = 4
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
yaml_file=str(Path.home() / ".memos" / "config.yaml"), yaml_file=str(Path.home() / ".memos" / "config.yaml"),
yaml_file_encoding="utf-8" yaml_file_encoding="utf-8",
) )
base_dir: str = str(Path.home() / ".memos") base_dir: str = str(Path.home() / ".memos")
@ -28,6 +33,10 @@ class Settings(BaseSettings):
# VLM plugin settings # VLM plugin settings
vlm: VLMSettings = VLMSettings() vlm: VLMSettings = VLMSettings()
# OCR plugin settings
ocr: OCRSettings = OCRSettings()
settings = Settings() settings = Settings()
# Define the default database path # Define the default database path
@ -38,4 +47,4 @@ TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
# Function to get the database path from environment variable or default # Function to get the database path from environment variable or default
def get_database_path(): def get_database_path():
return settings.database_path return settings.database_path

View File

@ -148,6 +148,25 @@ class LibraryPluginModel(Base):
) )
def initialize_default_plugins(session):
    """Seed the built-in plugin rows, skipping names that already exist.

    Args:
        session: SQLAlchemy session used to look up and insert
            ``PluginModel`` rows. It is committed before returning.
    """
    builtin_plugins = (
        PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
        PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"),
    )

    for candidate in builtin_plugins:
        # Plugins are matched by name only; an existing row is left untouched.
        already_present = (
            session.query(PluginModel).filter_by(name=candidate.name).first()
        )
        if already_present:
            continue
        session.add(candidate)

    # NOTE(review): indentation was lost in the rendered source; a single
    # commit after the loop is assumed here.
    session.commit()
# Create the database engine with the path from config # Create the database engine with the path from config
engine = create_engine(f"sqlite:///{get_database_path()}") engine = create_engine(f"sqlite:///{get_database_path()}")
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
# Initialize default plugins.
# These are module-level statements, so they run when this module is
# imported, after Base.metadata.create_all(engine) above.
# NOTE(review): indentation was flattened in this rendering; the call on the
# last line is presumably the body of the `with` statement.
from sqlalchemy.orm import sessionmaker
# Session factory bound to the engine created above.
Session = sessionmaker(bind=engine)
# Open a session, hand it to the seeding helper, and leave the context
# manager when seeding is done.
with Session() as session:
initialize_default_plugins(session)

View File

@ -0,0 +1,49 @@
# OCR Plugin
This README describes the OCR plugin. The plugin uses the `RapidOCR` library to perform OCR (Optical Character Recognition) on image files and updates the entity's metadata with the OCR results.
## How to Run
To run this OCR plugin, follow the steps below:
1. **Install the required dependencies:**
```bash
pip install -r requirements.txt
```
2. **Run the FastAPI application:**
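You can run the FastAPI application using `uvicorn`. Make sure you are in the directory where `main.py` is located. Note: this assumes `main.py` exposes a module-level `app`; if `app` is only created inside the `if __name__ == "__main__":` block, start the server with `python main.py --port 8000` instead.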
```bash
uvicorn main:app --host 0.0.0.0 --port 8000
```
3. **Integration with memos:**
```sh
$ python -m memos.commands plugin create ocr http://localhost:8000
Plugin created successfully
```
```sh
$ python -m memos.commands plugin ls
ID Name Description Webhook URL
1 ocr http://localhost:8000/
```
```sh
$ python -m memos.commands plugin bind --lib 1 --plugin 1
Plugin bound to library successfully
```
## Endpoints
- `GET /`: Health check endpoint. Returns the JSON body `{"healthy": true}` if the service is running.
- `POST /`: OCR endpoint. Accepts an `Entity` object and a `Location` header. Performs OCR on the image file and updates the entity's metadata with the OCR results.
## Metadata
The OCR results are stored in the metadata field named `ocr_result` with the following structure:

View File

View File

@ -8,17 +8,20 @@ import io
import os import os
from PIL import Image from PIL import Image
from fastapi import FastAPI, Request, HTTPException from fastapi import APIRouter, FastAPI, Request, HTTPException
from memos.schemas import Entity, MetadataType from memos.schemas import Entity, MetadataType
METADATA_FIELD_NAME = "ocr_result" METADATA_FIELD_NAME = "ocr_result"
PLUGIN_NAME = "ocr" PLUGIN_NAME = "ocr"
app = FastAPI() router = APIRouter(
tags=[PLUGIN_NAME],
responses={404: {"description": "Not found"}}
)
endpoint = None endpoint = None
token = None token = None
semaphore = asyncio.Semaphore(4) concurrency = None
semaphore = None
# Configure logger # Configure logger
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -50,6 +53,7 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
return response.json() return response.json()
# Modify the predict function to use semaphore
async def predict(img_path): async def predict(img_path):
image_base64 = image2base64(img_path) image_base64 = image2base64(img_path)
if not image_base64: if not image_base64:
@ -59,19 +63,18 @@ async def predict(img_path):
headers = {} headers = {}
if token: if token:
headers["Authorization"] = f"Bearer {token}" headers["Authorization"] = f"Bearer {token}"
ocr_result = await fetch(endpoint, client, image_base64, headers=headers) async with semaphore:
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
return ocr_result return ocr_result
app = FastAPI() @router.get("/")
@app.get("/")
async def read_root(): async def read_root():
return {"healthy": True} return {"healthy": True}
@app.post("/") @router.post("", include_in_schema=False)
@router.post("/")
async def ocr(entity: Entity, request: Request): async def ocr(entity: Entity, request: Request):
if not entity.file_type_group == "image": if not entity.file_type_group == "image":
return {METADATA_FIELD_NAME: "{}"} return {METADATA_FIELD_NAME: "{}"}
@ -123,27 +126,45 @@ async def ocr(entity: Entity, request: Request):
} }
def init_plugin(config):
    """Configure the OCR plugin's module-level state from *config*.

    Sets the ``endpoint``, ``token``, ``concurrency`` and ``semaphore``
    globals. ``predict`` holds ``semaphore`` around each OCR request, so
    ``concurrency`` caps the number of requests in flight at once.

    Args:
        config: Any object exposing ``endpoint``, ``token`` and
            ``concurrency`` attributes, e.g. the OCR settings section or
            the parsed command-line arguments.

    Raises:
        ValueError: If ``config.concurrency`` is less than 1. A semaphore
            created with 0 could never be acquired, so every request
            would wait forever.
    """
    global endpoint, token, concurrency, semaphore

    # Validate before touching any global so a bad config leaves the
    # previous state intact.
    if config.concurrency < 1:
        raise ValueError(
            f"concurrency must be at least 1, got {config.concurrency}"
        )

    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    semaphore = asyncio.Semaphore(concurrency)

    print(f"Endpoint: {endpoint}")
    # Never echo the bearer token itself; only report whether one is set.
    print(f"Token: {'***' if token else '(not set)'}")
    print(f"Concurrency: {concurrency}")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
import argparse import argparse
from fastapi import FastAPI
parser = argparse.ArgumentParser(description="OCR Plugin") parser = argparse.ArgumentParser(description="OCR Plugin")
parser.add_argument( parser.add_argument(
"--endpoint", "--endpoint",
type=str, type=str,
required=True, default="http://localhost:8080",
help="The endpoint URL for the OCR service", help="The endpoint URL for the OCR service",
) )
parser.add_argument( parser.add_argument(
"--token", type=str, required=False, help="The token for authentication" "--token", type=str, default="", help="The token for authentication"
)
parser.add_argument(
"--concurrency", type=int, default=4, help="The concurrency level"
) )
parser.add_argument( parser.add_argument(
"--port", type=int, default=8000, help="The port number to run the server on" "--port", type=int, default=8000, help="The port number to run the server on"
) )
args = parser.parse_args() args = parser.parse_args()
endpoint = args.endpoint
token = args.token
port = args.port
uvicorn.run(app, host="0.0.0.0", port=port) init_plugin(args)
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -88,6 +88,18 @@ app.mount(
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True) "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
) )
# NOTE(review): indentation was flattened in this rendering; the three
# statements after each `if` are presumably that `if`'s body.
# Add VLM plugin router: when enabled in settings, initialise the plugin from
# its settings section and expose its routes under /plugins/vlm, the same
# path used as webhook_url for the "buildin_vlm" default plugin.
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix="/plugins/vlm")
# Add OCR plugin router: same pattern, exposed under /plugins/ocr, the same
# path used as webhook_url for the "buildin_ocr" default plugin.
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix="/plugins/ocr")
@app.get("/favicon.png", response_class=FileResponse) @app.get("/favicon.png", response_class=FileResponse)
async def favicon_png(): async def favicon_png():
@ -178,8 +190,12 @@ async def trigger_webhooks(
location = str( location = str(
request.url_for("get_entity_by_id", entity_id=entity.id) request.url_for("get_entity_by_id", entity_id=entity.id)
) )
webhook_url = plugin.webhook_url
if webhook_url.startswith("/"):
webhook_url = str(request.base_url)[:-1] + webhook_url
print(f"webhook_url: {webhook_url}")
task = client.post( task = client.post(
plugin.webhook_url, webhook_url,
json=entity.model_dump(mode="json"), json=entity.model_dump(mode="json"),
headers={"Location": location}, headers={"Location": location},
timeout=60.0, timeout=60.0,
@ -683,19 +699,6 @@ async def get_file(file_path: str):
raise HTTPException(status_code=404, detail="File not found") raise HTTPException(status_code=404, detail="File not found")
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}")
# Add OCR plugin router
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix=f"/plugins/{ocr_main.PLUGIN_NAME}")
def run_server(): def run_server():
print("Database path:", get_database_path()) print("Database path:", get_database_path())
print( print(