feat(plugins): add ocr as built-in plugin

This commit is contained in:
arkohut 2024-08-25 16:48:28 +08:00
parent ec7ba1f989
commit 67a5e10d3e
14 changed files with 134 additions and 33 deletions

View File

@ -7,13 +7,18 @@ class VLMSettings(BaseModel):
enabled: bool = False
modelname: str = "internvl-1.5"
endpoint: str = "http://localhost:11434"
class OCRSettings(BaseModel):
enabled: bool = True
endpoint: str = "http://localhost:5555/predict"
token: str = ""
concurrency: int = 8
concurrency: int = 4
class Settings(BaseSettings):
model_config = SettingsConfigDict(
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
yaml_file_encoding="utf-8"
yaml_file_encoding="utf-8",
)
base_dir: str = str(Path.home() / ".memos")
@ -28,6 +33,10 @@ class Settings(BaseSettings):
# VLM plugin settings
vlm: VLMSettings = VLMSettings()
# OCR plugin settings
ocr: OCRSettings = OCRSettings()
settings = Settings()
# Define the default database path
@ -38,4 +47,4 @@ TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
# Function to get the database path from environment variable or default
def get_database_path():
return settings.database_path
return settings.database_path

View File

@ -148,6 +148,25 @@ class LibraryPluginModel(Base):
)
def initialize_default_plugins(session):
    """Seed the built-in plugin rows, skipping names that already exist.

    Args:
        session: SQLAlchemy session used to query and insert
            ``PluginModel`` rows; it is committed once at the end.
    """
    builtin_plugins = [
        PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
        PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"),
    ]

    for candidate in builtin_plugins:
        already_present = (
            session.query(PluginModel).filter_by(name=candidate.name).first()
        )
        if already_present:
            # A row with this name is already stored; leave it untouched.
            continue
        session.add(candidate)

    # A single commit after the loop persists every newly added row.
    session.commit()
# Create the database engine with the path from config
engine = create_engine(f"sqlite:///{get_database_path()}")
Base.metadata.create_all(engine)
# Initialize default plugins
from sqlalchemy.orm import sessionmaker
Session = sessionmaker(bind=engine)
with Session() as session:
initialize_default_plugins(session)

View File

@ -0,0 +1,49 @@
# OCR Plugin
This plugin uses the `RapidOCR` library to perform OCR (Optical Character Recognition) on image files and stores the OCR results in the entity's metadata.
## How to Run
To run this OCR plugin, follow the steps below:
1. **Install the required dependencies:**
```bash
pip install -r requirements.txt
```
2. **Run the FastAPI application:**
You can run the FastAPI application using `uvicorn`. Make sure you are in the directory where `main.py` is located.
```bash
uvicorn main:app --host 0.0.0.0 --port 8000
```
3. **Integration with memos:**
```sh
$ python -m memos.commands plugin create ocr http://localhost:8000
Plugin created successfully
```
```sh
$ python -m memos.commands plugin ls
ID Name Description Webhook URL
1 ocr http://localhost:8000/
```
```sh
$ python -m memos.commands plugin bind --lib 1 --plugin 1
Plugin bound to library successfully
```
## Endpoints
- `GET /`: Health check endpoint. Returns `{"healthy": true}` if the service is running.
- `POST /`: OCR endpoint. Accepts an `Entity` object and a `Location` header. Performs OCR on the image file and updates the entity's metadata with the OCR results.
## Metadata
The OCR results are stored in the metadata field named `ocr_result` with the following structure:

View File

View File

@ -8,17 +8,20 @@ import io
import os
from PIL import Image
from fastapi import FastAPI, Request, HTTPException
from fastapi import APIRouter, FastAPI, Request, HTTPException
from memos.schemas import Entity, MetadataType
METADATA_FIELD_NAME = "ocr_result"
PLUGIN_NAME = "ocr"
app = FastAPI()
router = APIRouter(
tags=[PLUGIN_NAME],
responses={404: {"description": "Not found"}}
)
endpoint = None
token = None
semaphore = asyncio.Semaphore(4)
concurrency = None
semaphore = None
# Configure logger
logging.basicConfig(level=logging.INFO)
@ -50,6 +53,7 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
return response.json()
# Modify the predict function to use semaphore
async def predict(img_path):
image_base64 = image2base64(img_path)
if not image_base64:
@ -59,19 +63,18 @@ async def predict(img_path):
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
async with semaphore:
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
return ocr_result
app = FastAPI()
@app.get("/")
@router.get("/")
async def read_root():
return {"healthy": True}
@app.post("/")
@router.post("", include_in_schema=False)
@router.post("/")
async def ocr(entity: Entity, request: Request):
if not entity.file_type_group == "image":
return {METADATA_FIELD_NAME: "{}"}
@ -123,27 +126,45 @@ async def ocr(entity: Entity, request: Request):
}
def init_plugin(config):
    """Configure the OCR plugin's module-level state from a settings object.

    Sets the module globals ``endpoint``, ``token``, ``concurrency`` and
    ``semaphore`` and prints a short summary of the configuration.

    Args:
        config: Object exposing ``endpoint``, ``token`` and ``concurrency``
            attributes.
    """
    global endpoint, token, concurrency, semaphore
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    # The semaphore bounds how many OCR requests may be in flight at once.
    semaphore = asyncio.Semaphore(concurrency)

    print(f"Endpoint: {endpoint}")
    # Report only whether a token is configured; printing the credential
    # itself would leak it into stdout and any captured logs.
    print(f"Token: {'<set>' if token else '<not set>'}")
    print(f"Concurrency: {concurrency}")
if __name__ == "__main__":
import uvicorn
import argparse
from fastapi import FastAPI
parser = argparse.ArgumentParser(description="OCR Plugin")
parser.add_argument(
"--endpoint",
type=str,
required=True,
default="http://localhost:8080",
help="The endpoint URL for the OCR service",
)
parser.add_argument(
"--token", type=str, required=False, help="The token for authentication"
"--token", type=str, default="", help="The token for authentication"
)
parser.add_argument(
"--concurrency", type=int, default=4, help="The concurrency level"
)
parser.add_argument(
"--port", type=int, default=8000, help="The port number to run the server on"
)
args = parser.parse_args()
endpoint = args.endpoint
token = args.token
port = args.port
uvicorn.run(app, host="0.0.0.0", port=port)
init_plugin(args)
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -88,6 +88,18 @@ app.mount(
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
)
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix="/plugins/vlm")
# Add OCR plugin router
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix="/plugins/ocr")
@app.get("/favicon.png", response_class=FileResponse)
async def favicon_png():
@ -178,8 +190,12 @@ async def trigger_webhooks(
location = str(
request.url_for("get_entity_by_id", entity_id=entity.id)
)
webhook_url = plugin.webhook_url
if webhook_url.startswith("/"):
webhook_url = str(request.base_url)[:-1] + webhook_url
print(f"webhook_url: {webhook_url}")
task = client.post(
plugin.webhook_url,
webhook_url,
json=entity.model_dump(mode="json"),
headers={"Location": location},
timeout=60.0,
@ -683,19 +699,6 @@ async def get_file(file_path: str):
raise HTTPException(status_code=404, detail="File not found")
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}")
# Add OCR plugin router
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix=f"/plugins/{ocr_main.PLUGIN_NAME}")
def run_server():
print("Database path:", get_database_path())
print(