mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-07 03:35:24 +00:00
feat(plugins): add ocr as build plugin
This commit is contained in:
parent
ec7ba1f989
commit
67a5e10d3e
@ -7,13 +7,18 @@ class VLMSettings(BaseModel):
|
||||
enabled: bool = False
|
||||
modelname: str = "internvl-1.5"
|
||||
endpoint: str = "http://localhost:11434"
|
||||
|
||||
class OCRSettings(BaseModel):
|
||||
enabled: bool = True
|
||||
endpoint: str = "http://localhost:5555/predict"
|
||||
token: str = ""
|
||||
concurrency: int = 8
|
||||
concurrency: int = 4
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
|
||||
yaml_file_encoding="utf-8"
|
||||
yaml_file_encoding="utf-8",
|
||||
)
|
||||
|
||||
base_dir: str = str(Path.home() / ".memos")
|
||||
@ -28,6 +33,10 @@ class Settings(BaseSettings):
|
||||
# VLM plugin settings
|
||||
vlm: VLMSettings = VLMSettings()
|
||||
|
||||
# OCR plugin settings
|
||||
ocr: OCRSettings = OCRSettings()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Define the default database path
|
||||
@ -38,4 +47,4 @@ TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
||||
|
||||
# Function to get the database path from environment variable or default
|
||||
def get_database_path():
|
||||
return settings.database_path
|
||||
return settings.database_path
|
||||
|
@ -148,6 +148,25 @@ class LibraryPluginModel(Base):
|
||||
)
|
||||
|
||||
|
||||
def initialize_default_plugins(session):
|
||||
default_plugins = [
|
||||
PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
|
||||
PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"),
|
||||
]
|
||||
|
||||
for plugin in default_plugins:
|
||||
existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first()
|
||||
if not existing_plugin:
|
||||
session.add(plugin)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Create the database engine with the path from config
|
||||
engine = create_engine(f"sqlite:///{get_database_path()}")
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
# Initialize default plugins
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
initialize_default_plugins(session)
|
49
memos/plugins/ocr/README.md
Normal file
49
memos/plugins/ocr/README.md
Normal file
@ -0,0 +1,49 @@
|
||||
# OCR Plugin
|
||||
|
||||
This is a README file for the OCR plugin. This plugin uses the `RapidOCR` library to perform OCR (Optical Character Recognition) on image files and updates the metadata of the entity with the OCR results.
|
||||
|
||||
## How to Run
|
||||
|
||||
To run this OCR plugin, follow the steps below:
|
||||
|
||||
1. **Install the required dependencies:**
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **Run the FastAPI application:**
|
||||
|
||||
You can run the FastAPI application using `uvicorn`. Make sure you are in the directory where `main.py` is located.
|
||||
|
||||
```bash
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
3. **Integration with memos:**
|
||||
|
||||
```sh
|
||||
$ python -m memos.commands plugin create ocr http://localhost:8000
|
||||
Plugin created successfully
|
||||
```
|
||||
|
||||
```sh
|
||||
$ python -m memos.commands plugin ls
|
||||
|
||||
ID Name Description Webhook URL
|
||||
1 ocr http://localhost:8000/
|
||||
```
|
||||
|
||||
```sh
|
||||
$ python -m memos.commands plugin bind --lib 1 --plugin 1
|
||||
Plugin bound to library successfully
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
- `GET /`: Health check endpoint. Returns `{"healthy": True}` if the service is running.
|
||||
- `POST /`: OCR endpoint. Accepts an `Entity` object and a `Location` header. Performs OCR on the image file and updates the entity's metadata with the OCR results.
|
||||
|
||||
## Metadata
|
||||
|
||||
The OCR results are stored in the metadata field named `ocr_result` with the following structure:
|
0
memos/plugins/ocr/__init__.py
Normal file
0
memos/plugins/ocr/__init__.py
Normal file
@ -8,17 +8,20 @@ import io
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
||||
from memos.schemas import Entity, MetadataType
|
||||
|
||||
METADATA_FIELD_NAME = "ocr_result"
|
||||
PLUGIN_NAME = "ocr"
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
router = APIRouter(
|
||||
tags=[PLUGIN_NAME],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
endpoint = None
|
||||
token = None
|
||||
semaphore = asyncio.Semaphore(4)
|
||||
concurrency = None
|
||||
semaphore = None
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -50,6 +53,7 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
|
||||
return response.json()
|
||||
|
||||
|
||||
# Modify the predict function to use semaphore
|
||||
async def predict(img_path):
|
||||
image_base64 = image2base64(img_path)
|
||||
if not image_base64:
|
||||
@ -59,19 +63,18 @@ async def predict(img_path):
|
||||
headers = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
|
||||
async with semaphore:
|
||||
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
|
||||
return ocr_result
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
@router.get("/")
|
||||
async def read_root():
|
||||
return {"healthy": True}
|
||||
|
||||
|
||||
@app.post("/")
|
||||
@router.post("", include_in_schema=False)
|
||||
@router.post("/")
|
||||
async def ocr(entity: Entity, request: Request):
|
||||
if not entity.file_type_group == "image":
|
||||
return {METADATA_FIELD_NAME: "{}"}
|
||||
@ -123,27 +126,45 @@ async def ocr(entity: Entity, request: Request):
|
||||
}
|
||||
|
||||
|
||||
def init_plugin(config):
|
||||
global endpoint, token, concurrency, semaphore
|
||||
endpoint = config.endpoint
|
||||
token = config.token
|
||||
concurrency = config.concurrency
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
print(f"Endpoint: {endpoint}")
|
||||
print(f"Token: {token}")
|
||||
print(f"Concurrency: {concurrency}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
import argparse
|
||||
from fastapi import FastAPI
|
||||
|
||||
parser = argparse.ArgumentParser(description="OCR Plugin")
|
||||
parser.add_argument(
|
||||
"--endpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
default="http://localhost:8080",
|
||||
help="The endpoint URL for the OCR service",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token", type=str, required=False, help="The token for authentication"
|
||||
"--token", type=str, default="", help="The token for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concurrency", type=int, default=4, help="The concurrency level"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to run the server on"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
endpoint = args.endpoint
|
||||
token = args.token
|
||||
port = args.port
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
init_plugin(args)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
@ -88,6 +88,18 @@ app.mount(
|
||||
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
|
||||
)
|
||||
|
||||
# Add VLM plugin router
|
||||
if settings.vlm.enabled:
|
||||
print("VLM plugin is enabled")
|
||||
vlm_main.init_plugin(settings.vlm)
|
||||
app.include_router(vlm_main.router, prefix="/plugins/vlm")
|
||||
|
||||
# Add OCR plugin router
|
||||
if settings.ocr.enabled:
|
||||
print("OCR plugin is enabled")
|
||||
ocr_main.init_plugin(settings.ocr)
|
||||
app.include_router(ocr_main.router, prefix="/plugins/ocr")
|
||||
|
||||
|
||||
@app.get("/favicon.png", response_class=FileResponse)
|
||||
async def favicon_png():
|
||||
@ -178,8 +190,12 @@ async def trigger_webhooks(
|
||||
location = str(
|
||||
request.url_for("get_entity_by_id", entity_id=entity.id)
|
||||
)
|
||||
webhook_url = plugin.webhook_url
|
||||
if webhook_url.startswith("/"):
|
||||
webhook_url = str(request.base_url)[:-1] + webhook_url
|
||||
print(f"webhook_url: {webhook_url}")
|
||||
task = client.post(
|
||||
plugin.webhook_url,
|
||||
webhook_url,
|
||||
json=entity.model_dump(mode="json"),
|
||||
headers={"Location": location},
|
||||
timeout=60.0,
|
||||
@ -683,19 +699,6 @@ async def get_file(file_path: str):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
|
||||
# Add VLM plugin router
|
||||
if settings.vlm.enabled:
|
||||
print("VLM plugin is enabled")
|
||||
vlm_main.init_plugin(settings.vlm)
|
||||
app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}")
|
||||
|
||||
# Add OCR plugin router
|
||||
if settings.ocr.enabled:
|
||||
print("OCR plugin is enabled")
|
||||
ocr_main.init_plugin(settings.ocr)
|
||||
app.include_router(ocr_main.router, prefix=f"/plugins/{ocr_main.PLUGIN_NAME}")
|
||||
|
||||
|
||||
def run_server():
|
||||
print("Database path:", get_database_path())
|
||||
print(
|
||||
|
Loading…
x
Reference in New Issue
Block a user