mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-09 12:37:12 +00:00
feat(plugins): add ocr as build plugin
This commit is contained in:
parent
ec7ba1f989
commit
67a5e10d3e
@ -7,13 +7,18 @@ class VLMSettings(BaseModel):
|
|||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
modelname: str = "internvl-1.5"
|
modelname: str = "internvl-1.5"
|
||||||
endpoint: str = "http://localhost:11434"
|
endpoint: str = "http://localhost:11434"
|
||||||
|
|
||||||
|
class OCRSettings(BaseModel):
|
||||||
|
enabled: bool = True
|
||||||
|
endpoint: str = "http://localhost:5555/predict"
|
||||||
token: str = ""
|
token: str = ""
|
||||||
concurrency: int = 8
|
concurrency: int = 4
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
|
yaml_file=str(Path.home() / ".memos" / "config.yaml"),
|
||||||
yaml_file_encoding="utf-8"
|
yaml_file_encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
base_dir: str = str(Path.home() / ".memos")
|
base_dir: str = str(Path.home() / ".memos")
|
||||||
@ -28,6 +33,10 @@ class Settings(BaseSettings):
|
|||||||
# VLM plugin settings
|
# VLM plugin settings
|
||||||
vlm: VLMSettings = VLMSettings()
|
vlm: VLMSettings = VLMSettings()
|
||||||
|
|
||||||
|
# OCR plugin settings
|
||||||
|
ocr: OCRSettings = OCRSettings()
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
# Define the default database path
|
# Define the default database path
|
||||||
@ -38,4 +47,4 @@ TYPESENSE_COLLECTION_NAME = settings.typesense_collection_name
|
|||||||
|
|
||||||
# Function to get the database path from environment variable or default
|
# Function to get the database path from environment variable or default
|
||||||
def get_database_path():
|
def get_database_path():
|
||||||
return settings.database_path
|
return settings.database_path
|
||||||
|
@ -148,6 +148,25 @@ class LibraryPluginModel(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_default_plugins(session):
|
||||||
|
default_plugins = [
|
||||||
|
PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
|
||||||
|
PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for plugin in default_plugins:
|
||||||
|
existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first()
|
||||||
|
if not existing_plugin:
|
||||||
|
session.add(plugin)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
# Create the database engine with the path from config
|
# Create the database engine with the path from config
|
||||||
engine = create_engine(f"sqlite:///{get_database_path()}")
|
engine = create_engine(f"sqlite:///{get_database_path()}")
|
||||||
Base.metadata.create_all(engine)
|
Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
# Initialize default plugins
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
Session = sessionmaker(bind=engine)
|
||||||
|
with Session() as session:
|
||||||
|
initialize_default_plugins(session)
|
49
memos/plugins/ocr/README.md
Normal file
49
memos/plugins/ocr/README.md
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# OCR Plugin
|
||||||
|
|
||||||
|
This is a README file for the OCR plugin. This plugin uses the `RapidOCR` library to perform OCR (Optical Character Recognition) on image files and updates the metadata of the entity with the OCR results.
|
||||||
|
|
||||||
|
## How to Run
|
||||||
|
|
||||||
|
To run this OCR plugin, follow the steps below:
|
||||||
|
|
||||||
|
1. **Install the required dependencies:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Run the FastAPI application:**
|
||||||
|
|
||||||
|
You can run the FastAPI application using `uvicorn`. Make sure you are in the directory where `main.py` is located.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn main:app --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Integration with memos:**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ python -m memos.commands plugin create ocr http://localhost:8000
|
||||||
|
Plugin created successfully
|
||||||
|
```
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ python -m memos.commands plugin ls
|
||||||
|
|
||||||
|
ID Name Description Webhook URL
|
||||||
|
1 ocr http://localhost:8000/
|
||||||
|
```
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ python -m memos.commands plugin bind --lib 1 --plugin 1
|
||||||
|
Plugin bound to library successfully
|
||||||
|
```
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
|
||||||
|
- `GET /`: Health check endpoint. Returns the JSON body `{"healthy": true}` if the service is running.
|
||||||
|
- `POST /`: OCR endpoint. Accepts an `Entity` object and a `Location` header. Performs OCR on the image file and updates the entity's metadata with the OCR results.
|
||||||
|
|
||||||
|
## Metadata
|
||||||
|
|
||||||
|
The OCR results are stored in the metadata field named `ocr_result` with the following structure:
|
0
memos/plugins/ocr/__init__.py
Normal file
0
memos/plugins/ocr/__init__.py
Normal file
@ -8,17 +8,20 @@ import io
|
|||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, HTTPException
|
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
||||||
from memos.schemas import Entity, MetadataType
|
from memos.schemas import Entity, MetadataType
|
||||||
|
|
||||||
METADATA_FIELD_NAME = "ocr_result"
|
METADATA_FIELD_NAME = "ocr_result"
|
||||||
PLUGIN_NAME = "ocr"
|
PLUGIN_NAME = "ocr"
|
||||||
|
|
||||||
app = FastAPI()
|
router = APIRouter(
|
||||||
|
tags=[PLUGIN_NAME],
|
||||||
|
responses={404: {"description": "Not found"}}
|
||||||
|
)
|
||||||
endpoint = None
|
endpoint = None
|
||||||
token = None
|
token = None
|
||||||
semaphore = asyncio.Semaphore(4)
|
concurrency = None
|
||||||
|
semaphore = None
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -50,6 +53,7 @@ async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = N
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
# Modify the predict function to use semaphore
|
||||||
async def predict(img_path):
|
async def predict(img_path):
|
||||||
image_base64 = image2base64(img_path)
|
image_base64 = image2base64(img_path)
|
||||||
if not image_base64:
|
if not image_base64:
|
||||||
@ -59,19 +63,18 @@ async def predict(img_path):
|
|||||||
headers = {}
|
headers = {}
|
||||||
if token:
|
if token:
|
||||||
headers["Authorization"] = f"Bearer {token}"
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
|
async with semaphore:
|
||||||
|
ocr_result = await fetch(endpoint, client, image_base64, headers=headers)
|
||||||
return ocr_result
|
return ocr_result
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
@router.get("/")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def read_root():
|
async def read_root():
|
||||||
return {"healthy": True}
|
return {"healthy": True}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
@router.post("", include_in_schema=False)
|
||||||
|
@router.post("/")
|
||||||
async def ocr(entity: Entity, request: Request):
|
async def ocr(entity: Entity, request: Request):
|
||||||
if not entity.file_type_group == "image":
|
if not entity.file_type_group == "image":
|
||||||
return {METADATA_FIELD_NAME: "{}"}
|
return {METADATA_FIELD_NAME: "{}"}
|
||||||
@ -123,27 +126,45 @@ async def ocr(entity: Entity, request: Request):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def init_plugin(config):
|
||||||
|
global endpoint, token, concurrency, semaphore
|
||||||
|
endpoint = config.endpoint
|
||||||
|
token = config.token
|
||||||
|
concurrency = config.concurrency
|
||||||
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
|
print(f"Endpoint: {endpoint}")
|
||||||
|
print(f"Token: {token}")
|
||||||
|
print(f"Concurrency: {concurrency}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import argparse
|
import argparse
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="OCR Plugin")
|
parser = argparse.ArgumentParser(description="OCR Plugin")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--endpoint",
|
"--endpoint",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
default="http://localhost:8080",
|
||||||
help="The endpoint URL for the OCR service",
|
help="The endpoint URL for the OCR service",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--token", type=str, required=False, help="The token for authentication"
|
"--token", type=str, default="", help="The token for authentication"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--concurrency", type=int, default=4, help="The concurrency level"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=8000, help="The port number to run the server on"
|
"--port", type=int, default=8000, help="The port number to run the server on"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
endpoint = args.endpoint
|
|
||||||
token = args.token
|
|
||||||
port = args.port
|
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
init_plugin(args)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
@ -88,6 +88,18 @@ app.mount(
|
|||||||
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
|
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add VLM plugin router
|
||||||
|
if settings.vlm.enabled:
|
||||||
|
print("VLM plugin is enabled")
|
||||||
|
vlm_main.init_plugin(settings.vlm)
|
||||||
|
app.include_router(vlm_main.router, prefix="/plugins/vlm")
|
||||||
|
|
||||||
|
# Add OCR plugin router
|
||||||
|
if settings.ocr.enabled:
|
||||||
|
print("OCR plugin is enabled")
|
||||||
|
ocr_main.init_plugin(settings.ocr)
|
||||||
|
app.include_router(ocr_main.router, prefix="/plugins/ocr")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/favicon.png", response_class=FileResponse)
|
@app.get("/favicon.png", response_class=FileResponse)
|
||||||
async def favicon_png():
|
async def favicon_png():
|
||||||
@ -178,8 +190,12 @@ async def trigger_webhooks(
|
|||||||
location = str(
|
location = str(
|
||||||
request.url_for("get_entity_by_id", entity_id=entity.id)
|
request.url_for("get_entity_by_id", entity_id=entity.id)
|
||||||
)
|
)
|
||||||
|
webhook_url = plugin.webhook_url
|
||||||
|
if webhook_url.startswith("/"):
|
||||||
|
webhook_url = str(request.base_url)[:-1] + webhook_url
|
||||||
|
print(f"webhook_url: {webhook_url}")
|
||||||
task = client.post(
|
task = client.post(
|
||||||
plugin.webhook_url,
|
webhook_url,
|
||||||
json=entity.model_dump(mode="json"),
|
json=entity.model_dump(mode="json"),
|
||||||
headers={"Location": location},
|
headers={"Location": location},
|
||||||
timeout=60.0,
|
timeout=60.0,
|
||||||
@ -683,19 +699,6 @@ async def get_file(file_path: str):
|
|||||||
raise HTTPException(status_code=404, detail="File not found")
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
|
||||||
# Add VLM plugin router
|
|
||||||
if settings.vlm.enabled:
|
|
||||||
print("VLM plugin is enabled")
|
|
||||||
vlm_main.init_plugin(settings.vlm)
|
|
||||||
app.include_router(vlm_main.router, prefix=f"/plugins/{vlm_main.PLUGIN_NAME}")
|
|
||||||
|
|
||||||
# Add OCR plugin router
|
|
||||||
if settings.ocr.enabled:
|
|
||||||
print("OCR plugin is enabled")
|
|
||||||
ocr_main.init_plugin(settings.ocr)
|
|
||||||
app.include_router(ocr_main.router, prefix=f"/plugins/{ocr_main.PLUGIN_NAME}")
|
|
||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
print("Database path:", get_database_path())
|
print("Database path:", get_database_path())
|
||||||
print(
|
print(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user