mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 19:25:24 +00:00
feat(vlm): force jpeg option
This commit is contained in:
parent
bc205eca11
commit
e99792a974
@ -145,7 +145,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
|
||||
updated_file_count = 0
|
||||
added_file_count = 0
|
||||
scanned_files = set()
|
||||
semaphore = asyncio.Semaphore(4)
|
||||
semaphore = asyncio.Semaphore(settings.batchsize)
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
tasks = []
|
||||
for root, _, files in os.walk(folder_path):
|
||||
@ -608,7 +608,7 @@ def index(
|
||||
pbar.refresh()
|
||||
|
||||
# Index each entity
|
||||
batch_size = 8
|
||||
batch_size = settings.batchsize
|
||||
for i in range(0, len(entities), batch_size):
|
||||
batch = entities[i : i + batch_size]
|
||||
to_index = []
|
||||
|
@ -18,6 +18,7 @@ class VLMSettings(BaseModel):
|
||||
endpoint: str = "http://localhost:11434"
|
||||
token: str = ""
|
||||
concurrency: int = 4
|
||||
force_jpeg: bool = False  # When True, images are converted to RGB and re-encoded as JPEG before base64 encoding
|
||||
|
||||
|
||||
class OCRSettings(BaseModel):
|
||||
@ -64,6 +65,9 @@ class Settings(BaseSettings):
|
||||
# Embedding settings
|
||||
embedding: EmbeddingSettings = EmbeddingSettings()
|
||||
|
||||
# Batch size used for the file-scan semaphore and for entity indexing batches
|
||||
batchsize: int = 4
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
|
@ -8,7 +8,7 @@ from memos.schemas import Entity, MetadataType
|
||||
import logging
|
||||
import uvicorn
|
||||
import os
|
||||
|
||||
import io
|
||||
|
||||
PLUGIN_NAME = "vlm"
|
||||
PROMPT = "描述这张图片的内容"
|
||||
@ -20,6 +20,7 @@ endpoint = None
|
||||
token = None
|
||||
concurrency = None
|
||||
semaphore = None
|
||||
force_jpeg = None
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -28,13 +29,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def image2base64(img_path):
|
||||
try:
|
||||
# Attempt to open and verify the image
|
||||
with Image.open(img_path) as img:
|
||||
img.verify() # Verify that it's a valid image file
|
||||
img.verify()
|
||||
|
||||
# If verification passes, encode the image
|
||||
with open(img_path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
with Image.open(img_path) as img:
|
||||
if force_jpeg:
|
||||
# Convert image to RGB mode (removes alpha channel if present)
|
||||
img = img.convert('RGB')
|
||||
# Save as JPEG in memory
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='JPEG')
|
||||
buffer.seek(0)
|
||||
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
else:
|
||||
# Use original format
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format=img.format)
|
||||
buffer.seek(0)
|
||||
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
return encoded_string
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image {img_path}: {str(e)}")
|
||||
@ -72,20 +84,19 @@ async def predict(
|
||||
if not img_base64:
|
||||
return None
|
||||
|
||||
# Get the file extension
|
||||
_, file_extension = os.path.splitext(img_path)
|
||||
file_extension = file_extension.lower()[
|
||||
1:
|
||||
] # Remove the dot and convert to lowercase
|
||||
mime_type = "image/jpeg" if force_jpeg else "image/jpeg"  # NOTE: both branches are "image/jpeg"; when force_jpeg is off the value is replaced below by the extension lookup
|
||||
|
||||
# Determine the MIME type
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
mime_type = mime_types.get(file_extension, "image/jpeg")
|
||||
if not force_jpeg:
|
||||
# Only determine MIME type if not forcing JPEG
|
||||
_, file_extension = os.path.splitext(img_path)
|
||||
file_extension = file_extension.lower()[1:]
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
mime_type = mime_types.get(file_extension, "image/jpeg")
|
||||
|
||||
request_data = {
|
||||
"model": modelname,
|
||||
@ -188,11 +199,12 @@ async def vlm(entity: Entity, request: Request):
|
||||
|
||||
|
||||
def init_plugin(config):
|
||||
global modelname, endpoint, token, concurrency, semaphore
|
||||
global modelname, endpoint, token, concurrency, semaphore, force_jpeg
|
||||
modelname = config.modelname
|
||||
endpoint = config.endpoint
|
||||
token = config.token
|
||||
concurrency = config.concurrency
|
||||
force_jpeg = config.force_jpeg
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
# Print the parameters
|
||||
@ -201,6 +213,7 @@ def init_plugin(config):
|
||||
logger.info(f"Endpoint: {endpoint}")
|
||||
logger.info(f"Token: {token}")
|
||||
logger.info(f"Concurrency: {concurrency}")
|
||||
logger.info(f"Force JPEG: {force_jpeg}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user