mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-10 13:07:15 +00:00
feat(ocr): force max image size 1920 x 1920
This commit is contained in:
parent
bd55696748
commit
7c41769516
@ -11,6 +11,9 @@ from rapidocr_onnxruntime import RapidOCR
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import yaml
|
import yaml
|
||||||
|
import io
|
||||||
|
|
||||||
|
MAX_THUMBNAIL_SIZE = (1920, 1920)
|
||||||
|
|
||||||
from fastapi import APIRouter, Request, HTTPException
|
from fastapi import APIRouter, Request, HTTPException
|
||||||
from memos.schemas import Entity, MetadataType
|
from memos.schemas import Entity, MetadataType
|
||||||
@ -35,9 +38,11 @@ logger = logging.getLogger(__name__)
|
|||||||
def image2base64(img_path):
|
def image2base64(img_path):
|
||||||
try:
|
try:
|
||||||
with Image.open(img_path) as img:
|
with Image.open(img_path) as img:
|
||||||
img.convert("RGB") # Check if image is not broken
|
img = img.convert("RGB")
|
||||||
with open(img_path, "rb") as image_file:
|
img.thumbnail(MAX_THUMBNAIL_SIZE)
|
||||||
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
|
buffered = io.BytesIO()
|
||||||
|
img.save(buffered, format="JPEG")
|
||||||
|
encoded_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
return encoded_string
|
return encoded_string
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing image {img_path}: {str(e)}")
|
logger.error(f"Error processing image {img_path}: {str(e)}")
|
||||||
@ -71,6 +76,8 @@ def convert_ocr_results(results):
|
|||||||
def predict_local(img_path):
|
def predict_local(img_path):
|
||||||
try:
|
try:
|
||||||
with Image.open(img_path) as img:
|
with Image.open(img_path) as img:
|
||||||
|
img = img.convert("RGB")
|
||||||
|
img.thumbnail(MAX_THUMBNAIL_SIZE)
|
||||||
img_array = np.array(img)
|
img_array = np.array(img)
|
||||||
results, _ = ocr(img_array)
|
results, _ = ocr(img_array)
|
||||||
return convert_ocr_results(results)
|
return convert_ocr_results(results)
|
||||||
@ -128,7 +135,6 @@ async def ocr(entity: Entity, request: Request):
|
|||||||
patch_url = f"{location_url}/metadata"
|
patch_url = f"{location_url}/metadata"
|
||||||
|
|
||||||
ocr_result = await predict(entity.filepath)
|
ocr_result = await predict(entity.filepath)
|
||||||
|
|
||||||
logger.info(ocr_result)
|
logger.info(ocr_result)
|
||||||
if not ocr_result:
|
if not ocr_result:
|
||||||
logger.info(f"No OCR result found for file: {entity.filepath}")
|
logger.info(f"No OCR result found for file: {entity.filepath}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user