feat(vlm): force jpeg option

This commit is contained in:
arkohut 2024-09-03 18:35:41 +08:00
parent bc205eca11
commit e99792a974
3 changed files with 39 additions and 22 deletions

View File

@ -145,7 +145,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
updated_file_count = 0
added_file_count = 0
scanned_files = set()
semaphore = asyncio.Semaphore(4)
semaphore = asyncio.Semaphore(settings.batchsize)
async with httpx.AsyncClient(timeout=60) as client:
tasks = []
for root, _, files in os.walk(folder_path):
@ -608,7 +608,7 @@ def index(
pbar.refresh()
# Index each entity
batch_size = 8
batch_size = settings.batchsize
for i in range(0, len(entities), batch_size):
batch = entities[i : i + batch_size]
to_index = []

View File

@ -18,6 +18,7 @@ class VLMSettings(BaseModel):
endpoint: str = "http://localhost:11434"
token: str = ""
concurrency: int = 4
force_jpeg: bool = False # Add this line
class OCRSettings(BaseModel):
@ -64,6 +65,9 @@ class Settings(BaseSettings):
# Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings()
# New batchsize setting
batchsize: int = 4
@classmethod
def settings_customise_sources(
cls,

View File

@ -8,7 +8,7 @@ from memos.schemas import Entity, MetadataType
import logging
import uvicorn
import os
import io
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@ -20,6 +20,7 @@ endpoint = None
token = None
concurrency = None
semaphore = None
force_jpeg = None
# Configure logger
logging.basicConfig(level=logging.INFO)
@ -28,13 +29,24 @@ logger = logging.getLogger(__name__)
def image2base64(img_path):
try:
# Attempt to open and verify the image
with Image.open(img_path) as img:
img.verify() # Verify that it's a valid image file
img.verify()
# If verification passes, encode the image
with open(img_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
with Image.open(img_path) as img:
if force_jpeg:
# Convert image to RGB mode (removes alpha channel if present)
img = img.convert('RGB')
# Save as JPEG in memory
buffer = io.BytesIO()
img.save(buffer, format='JPEG')
buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
else:
# Use original format
buffer = io.BytesIO()
img.save(buffer, format=img.format)
buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
return encoded_string
except Exception as e:
logger.error(f"Error processing image {img_path}: {str(e)}")
@ -72,20 +84,19 @@ async def predict(
if not img_base64:
return None
# Get the file extension
_, file_extension = os.path.splitext(img_path)
file_extension = file_extension.lower()[
1:
] # Remove the dot and convert to lowercase
mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Default to JPEG if force_jpeg is True
# Determine the MIME type
mime_types = {
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"webp": "image/webp",
}
mime_type = mime_types.get(file_extension, "image/jpeg")
if not force_jpeg:
# Only determine MIME type if not forcing JPEG
_, file_extension = os.path.splitext(img_path)
file_extension = file_extension.lower()[1:]
mime_types = {
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"webp": "image/webp",
}
mime_type = mime_types.get(file_extension, "image/jpeg")
request_data = {
"model": modelname,
@ -188,11 +199,12 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore
global modelname, endpoint, token, concurrency, semaphore, force_jpeg
modelname = config.modelname
endpoint = config.endpoint
token = config.token
concurrency = config.concurrency
force_jpeg = config.force_jpeg
semaphore = asyncio.Semaphore(concurrency)
# Print the parameters
@ -201,6 +213,7 @@ def init_plugin(config):
logger.info(f"Endpoint: {endpoint}")
logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}")
if __name__ == "__main__":