diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index 2beed73..e84d7c6 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -9,6 +9,7 @@ import logging import uvicorn import os import io +import numpy as np PLUGIN_NAME = "vlm" @@ -31,21 +32,35 @@ logger = logging.getLogger(__name__) def image2base64(img_path): try: with Image.open(img_path) as img: - img.verify() + img.verify() # Verify the image file with Image.open(img_path) as img: + # Check image size and skip if it's too small + if img.width < 10 or img.height < 10: + logger.warning(f"Image is too small: {img.width}x{img.height}. Skipping processing.") + return None + + # Convert image to RGB mode (removes alpha channel if present) + img = img.convert("RGB") + + # Convert to numpy array and check shape + img_array = np.array(img) + logger.info(f"Image shape: {img_array.shape}") + + if img_array.shape[2] != 3: + logger.warning(f"Unexpected number of channels: {img_array.shape[2]}. Expected 3. Skipping processing.") + return None + if force_jpeg: - # Convert image to RGB mode (removes alpha channel if present) - img = img.convert("RGB") # Save as JPEG in memory buffer = io.BytesIO() img.save(buffer, format="JPEG") buffer.seek(0) encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8") else: - # Use original format + # Use original format, but ensure it's RGB buffer = io.BytesIO() - img.save(buffer, format=img.format) + img.save(buffer, format=img.format or "JPEG") buffer.seek(0) encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8") return encoded_string @@ -89,6 +104,7 @@ async def predict_remote( ) -> Optional[str]: img_base64 = image2base64(img_path) if not img_base64: + logger.warning(f"Skipping processing for file: {img_path} due to invalid or small image") return None mime_type = (