mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-10 21:17:14 +00:00
feat(vlm): force jpeg option
This commit is contained in:
parent
bc205eca11
commit
e99792a974
@ -145,7 +145,7 @@ async def loop_files(library_id, folder, folder_path, force, plugins):
|
|||||||
updated_file_count = 0
|
updated_file_count = 0
|
||||||
added_file_count = 0
|
added_file_count = 0
|
||||||
scanned_files = set()
|
scanned_files = set()
|
||||||
semaphore = asyncio.Semaphore(4)
|
semaphore = asyncio.Semaphore(settings.batchsize)
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
tasks = []
|
tasks = []
|
||||||
for root, _, files in os.walk(folder_path):
|
for root, _, files in os.walk(folder_path):
|
||||||
@ -608,7 +608,7 @@ def index(
|
|||||||
pbar.refresh()
|
pbar.refresh()
|
||||||
|
|
||||||
# Index each entity
|
# Index each entity
|
||||||
batch_size = 8
|
batch_size = settings.batchsize
|
||||||
for i in range(0, len(entities), batch_size):
|
for i in range(0, len(entities), batch_size):
|
||||||
batch = entities[i : i + batch_size]
|
batch = entities[i : i + batch_size]
|
||||||
to_index = []
|
to_index = []
|
||||||
|
@ -18,6 +18,7 @@ class VLMSettings(BaseModel):
|
|||||||
endpoint: str = "http://localhost:11434"
|
endpoint: str = "http://localhost:11434"
|
||||||
token: str = ""
|
token: str = ""
|
||||||
concurrency: int = 4
|
concurrency: int = 4
|
||||||
|
force_jpeg: bool = False # Add this line
|
||||||
|
|
||||||
|
|
||||||
class OCRSettings(BaseModel):
|
class OCRSettings(BaseModel):
|
||||||
@ -64,6 +65,9 @@ class Settings(BaseSettings):
|
|||||||
# Embedding settings
|
# Embedding settings
|
||||||
embedding: EmbeddingSettings = EmbeddingSettings()
|
embedding: EmbeddingSettings = EmbeddingSettings()
|
||||||
|
|
||||||
|
# New batchsize setting
|
||||||
|
batchsize: int = 4
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def settings_customise_sources(
|
def settings_customise_sources(
|
||||||
cls,
|
cls,
|
||||||
|
@ -8,7 +8,7 @@ from memos.schemas import Entity, MetadataType
|
|||||||
import logging
|
import logging
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import os
|
import os
|
||||||
|
import io
|
||||||
|
|
||||||
PLUGIN_NAME = "vlm"
|
PLUGIN_NAME = "vlm"
|
||||||
PROMPT = "描述这张图片的内容"
|
PROMPT = "描述这张图片的内容"
|
||||||
@ -20,6 +20,7 @@ endpoint = None
|
|||||||
token = None
|
token = None
|
||||||
concurrency = None
|
concurrency = None
|
||||||
semaphore = None
|
semaphore = None
|
||||||
|
force_jpeg = None
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -28,13 +29,24 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def image2base64(img_path):
|
def image2base64(img_path):
|
||||||
try:
|
try:
|
||||||
# Attempt to open and verify the image
|
|
||||||
with Image.open(img_path) as img:
|
with Image.open(img_path) as img:
|
||||||
img.verify() # Verify that it's a valid image file
|
img.verify()
|
||||||
|
|
||||||
# If verification passes, encode the image
|
with Image.open(img_path) as img:
|
||||||
with open(img_path, "rb") as image_file:
|
if force_jpeg:
|
||||||
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
|
# Convert image to RGB mode (removes alpha channel if present)
|
||||||
|
img = img.convert('RGB')
|
||||||
|
# Save as JPEG in memory
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
img.save(buffer, format='JPEG')
|
||||||
|
buffer.seek(0)
|
||||||
|
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||||
|
else:
|
||||||
|
# Use original format
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
img.save(buffer, format=img.format)
|
||||||
|
buffer.seek(0)
|
||||||
|
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||||
return encoded_string
|
return encoded_string
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing image {img_path}: {str(e)}")
|
logger.error(f"Error processing image {img_path}: {str(e)}")
|
||||||
@ -72,20 +84,19 @@ async def predict(
|
|||||||
if not img_base64:
|
if not img_base64:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get the file extension
|
mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Default to JPEG if force_jpeg is True
|
||||||
_, file_extension = os.path.splitext(img_path)
|
|
||||||
file_extension = file_extension.lower()[
|
|
||||||
1:
|
|
||||||
] # Remove the dot and convert to lowercase
|
|
||||||
|
|
||||||
# Determine the MIME type
|
if not force_jpeg:
|
||||||
mime_types = {
|
# Only determine MIME type if not forcing JPEG
|
||||||
"png": "image/png",
|
_, file_extension = os.path.splitext(img_path)
|
||||||
"jpg": "image/jpeg",
|
file_extension = file_extension.lower()[1:]
|
||||||
"jpeg": "image/jpeg",
|
mime_types = {
|
||||||
"webp": "image/webp",
|
"png": "image/png",
|
||||||
}
|
"jpg": "image/jpeg",
|
||||||
mime_type = mime_types.get(file_extension, "image/jpeg")
|
"jpeg": "image/jpeg",
|
||||||
|
"webp": "image/webp",
|
||||||
|
}
|
||||||
|
mime_type = mime_types.get(file_extension, "image/jpeg")
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"model": modelname,
|
"model": modelname,
|
||||||
@ -188,11 +199,12 @@ async def vlm(entity: Entity, request: Request):
|
|||||||
|
|
||||||
|
|
||||||
def init_plugin(config):
|
def init_plugin(config):
|
||||||
global modelname, endpoint, token, concurrency, semaphore
|
global modelname, endpoint, token, concurrency, semaphore, force_jpeg
|
||||||
modelname = config.modelname
|
modelname = config.modelname
|
||||||
endpoint = config.endpoint
|
endpoint = config.endpoint
|
||||||
token = config.token
|
token = config.token
|
||||||
concurrency = config.concurrency
|
concurrency = config.concurrency
|
||||||
|
force_jpeg = config.force_jpeg
|
||||||
semaphore = asyncio.Semaphore(concurrency)
|
semaphore = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
# Print the parameters
|
# Print the parameters
|
||||||
@ -201,6 +213,7 @@ def init_plugin(config):
|
|||||||
logger.info(f"Endpoint: {endpoint}")
|
logger.info(f"Endpoint: {endpoint}")
|
||||||
logger.info(f"Token: {token}")
|
logger.info(f"Token: {token}")
|
||||||
logger.info(f"Concurrency: {concurrency}")
|
logger.info(f"Concurrency: {concurrency}")
|
||||||
|
logger.info(f"Force JPEG: {force_jpeg}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user