feat(vlm): force jpeg option

This commit is contained in:
arkohut 2024-09-03 18:35:41 +08:00
parent bc205eca11
commit e99792a974
3 changed files with 39 additions and 22 deletions

View File

@ -145,7 +145,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
updated_file_count = 0 updated_file_count = 0
added_file_count = 0 added_file_count = 0
scanned_files = set() scanned_files = set()
semaphore = asyncio.Semaphore(4) semaphore = asyncio.Semaphore(settings.batchsize)
async with httpx.AsyncClient(timeout=60) as client: async with httpx.AsyncClient(timeout=60) as client:
tasks = [] tasks = []
for root, _, files in os.walk(folder_path): for root, _, files in os.walk(folder_path):
@ -608,7 +608,7 @@ def index(
pbar.refresh() pbar.refresh()
# Index each entity # Index each entity
batch_size = 8 batch_size = settings.batchsize
for i in range(0, len(entities), batch_size): for i in range(0, len(entities), batch_size):
batch = entities[i : i + batch_size] batch = entities[i : i + batch_size]
to_index = [] to_index = []

View File

@ -18,6 +18,7 @@ class VLMSettings(BaseModel):
endpoint: str = "http://localhost:11434" endpoint: str = "http://localhost:11434"
token: str = "" token: str = ""
concurrency: int = 4 concurrency: int = 4
force_jpeg: bool = False # When True, images are converted to RGB and re-encoded as JPEG before being sent to the VLM
class OCRSettings(BaseModel): class OCRSettings(BaseModel):
@ -64,6 +65,9 @@ class Settings(BaseSettings):
# Embedding settings # Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings() embedding: EmbeddingSettings = EmbeddingSettings()
# Batch size: used as the scan semaphore limit in loop_files and as the indexing batch size in index
batchsize: int = 4
@classmethod @classmethod
def settings_customise_sources( def settings_customise_sources(
cls, cls,

View File

@ -8,7 +8,7 @@ from memos.schemas import Entity, MetadataType
import logging import logging
import uvicorn import uvicorn
import os import os
import io
PLUGIN_NAME = "vlm" PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容" PROMPT = "描述这张图片的内容"
@ -20,6 +20,7 @@ endpoint = None
token = None token = None
concurrency = None concurrency = None
semaphore = None semaphore = None
force_jpeg = None
# Configure logger # Configure logger
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -28,13 +29,24 @@ logger = logging.getLogger(__name__)
def image2base64(img_path): def image2base64(img_path):
try: try:
# Attempt to open and verify the image
with Image.open(img_path) as img: with Image.open(img_path) as img:
img.verify() # Verify that it's a valid image file img.verify()
# If verification passes, encode the image with Image.open(img_path) as img:
with open(img_path, "rb") as image_file: if force_jpeg:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8") # Convert image to RGB mode (removes alpha channel if present)
img = img.convert('RGB')
# Save as JPEG in memory
buffer = io.BytesIO()
img.save(buffer, format='JPEG')
buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
else:
# Use original format
buffer = io.BytesIO()
img.save(buffer, format=img.format)
buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
return encoded_string return encoded_string
except Exception as e: except Exception as e:
logger.error(f"Error processing image {img_path}: {str(e)}") logger.error(f"Error processing image {img_path}: {str(e)}")
@ -72,20 +84,19 @@ async def predict(
if not img_base64: if not img_base64:
return None return None
# Get the file extension mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Start from JPEG (both branches are identical); overridden from the file extension below unless force_jpeg is set
_, file_extension = os.path.splitext(img_path)
file_extension = file_extension.lower()[
1:
] # Remove the dot and convert to lowercase
# Determine the MIME type if not force_jpeg:
mime_types = { # Only determine MIME type if not forcing JPEG
"png": "image/png", _, file_extension = os.path.splitext(img_path)
"jpg": "image/jpeg", file_extension = file_extension.lower()[1:]
"jpeg": "image/jpeg", mime_types = {
"webp": "image/webp", "png": "image/png",
} "jpg": "image/jpeg",
mime_type = mime_types.get(file_extension, "image/jpeg") "jpeg": "image/jpeg",
"webp": "image/webp",
}
mime_type = mime_types.get(file_extension, "image/jpeg")
request_data = { request_data = {
"model": modelname, "model": modelname,
@ -188,11 +199,12 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config): def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore global modelname, endpoint, token, concurrency, semaphore, force_jpeg
modelname = config.modelname modelname = config.modelname
endpoint = config.endpoint endpoint = config.endpoint
token = config.token token = config.token
concurrency = config.concurrency concurrency = config.concurrency
force_jpeg = config.force_jpeg
semaphore = asyncio.Semaphore(concurrency) semaphore = asyncio.Semaphore(concurrency)
# Print the parameters # Print the parameters
@ -201,6 +213,7 @@ def init_plugin(config):
logger.info(f"Endpoint: {endpoint}") logger.info(f"Endpoint: {endpoint}")
logger.info(f"Token: {token}") logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}") logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}")
if __name__ == "__main__": if __name__ == "__main__":