Mirror of https://github.com/tcsenpai/pensieve.git, synced 2025-06-06 19:25:24 +00:00
feat(ml_backend): move florence 2 as a default vlm plugin
parent d7e6c32e86
commit 7e43bc0861
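This commit makes Florence-2 the default captioning backend for the VLM plugin: a new use_local setting (default True) loads microsoft/Florence-2-base-ft through transformers and routes predict() to a local inference path instead of the remote HTTP endpoint. Condensed from the hunks below, the local path is the standard Florence-2 captioning recipe. A minimal standalone sketch (model name, task prompt, and generation arguments are taken from the diff; the image path and the device/dtype selection here are simplified assumptions):

    import torch
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoProcessor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base-ft",
        torch_dtype=dtype,
        attn_implementation="sdpa",
        trust_remote_code=True,
    ).to(device)
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base-ft", trust_remote_code=True
    )

    task_prompt = "<MORE_DETAILED_CAPTION>"
    image = Image.open("screenshot.png")  # hypothetical input image
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(device, dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    caption = processor.post_process_generation(
        text, task=task_prompt, image_size=(image.width, image.height)
    )[task_prompt]
    print(caption)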
@@ -19,6 +19,7 @@ class VLMSettings(BaseModel):
     token: str = ""
     concurrency: int = 4
     force_jpeg: bool = False
+    use_local: bool = True
 
 
 class OCRSettings(BaseModel):
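Because use_local defaults to True, any deployment that does not override it opts into local inference. A quick illustration of the defaults (only the fields visible in the hunk above; the real VLMSettings has more):

    from pydantic import BaseModel

    class VLMSettings(BaseModel):
        # Subset of the fields shown in the hunk above.
        token: str = ""
        concurrency: int = 4
        force_jpeg: bool = False
        use_local: bool = True

    print(VLMSettings().use_local)                 # True -> local Florence-2 path
    print(VLMSettings(use_local=False).use_local)  # False -> remote endpoint path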
@@ -145,6 +145,7 @@ async def ocr(entity: Entity, request: Request):
                    }
                ]
            },
+            timeout=30,
        )
 
        # Check if the patch request was successful
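The added timeout=30 is httpx's per-request timeout in seconds, so a hung metadata PATCH fails instead of blocking the plugin indefinitely. A hedged sketch of the pattern (URL, payload, and function name are placeholders; the actual call site is only partially visible above):

    import httpx

    async def patch_metadata(url: str, payload: dict) -> httpx.Response:
        async with httpx.AsyncClient() as client:
            # timeout=30 bounds this single request to 30 seconds
            return await client.patch(url, json=payload, timeout=30)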
@@ -9,6 +9,8 @@ import logging
 import uvicorn
 import os
 import io
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
 
 PLUGIN_NAME = "vlm"
 PROMPT = "描述这张图片的内容"
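One side effect of importing torch and transformers at module top: they become hard dependencies of the plugin even when use_local is False. A common alternative, not what this diff does, is to defer the heavy imports into the initialization path (a sketch; the helper name is hypothetical):

    def _load_florence(model_name: str = "microsoft/Florence-2-base-ft"):
        # Deferred imports: the heavy dependencies are only touched when
        # local inference is actually enabled.
        import torch
        from transformers import AutoModelForCausalLM, AutoProcessor

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device)
        processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        return model, processor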
@@ -21,6 +23,10 @@ token = None
 concurrency = None
 semaphore = None
 force_jpeg = None
+use_local = None
+florence_model = None
+florence_processor = None
+torch_dtype = None
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -35,18 +41,18 @@ def image2base64(img_path):
         with Image.open(img_path) as img:
             if force_jpeg:
                 # Convert image to RGB mode (removes alpha channel if present)
-                img = img.convert('RGB')
+                img = img.convert("RGB")
                 # Save as JPEG in memory
                 buffer = io.BytesIO()
-                img.save(buffer, format='JPEG')
+                img.save(buffer, format="JPEG")
                 buffer.seek(0)
-                encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
+                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
             else:
                 # Use original format
                 buffer = io.BytesIO()
                 img.save(buffer, format=img.format)
                 buffer.seek(0)
-                encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
+                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
             return encoded_string
     except Exception as e:
         logger.error(f"Error processing image {img_path}: {str(e)}")
@@ -79,12 +85,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None
 
 async def predict(
     endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
+) -> Optional[str]:
+    if use_local:
+        return await predict_local(img_path)
+    else:
+        return await predict_remote(endpoint, modelname, img_path, token)
+
+
+async def predict_local(img_path: str) -> Optional[str]:
+    try:
+        image = Image.open(img_path)
+        task_prompt = "<MORE_DETAILED_CAPTION>"
+        prompt = task_prompt + ""
+
+        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
+            florence_model.device, torch_dtype
+        )
+
+        generated_ids = florence_model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=1024,
+            do_sample=False,
+            num_beams=3,
+        )
+
+        generated_texts = florence_processor.batch_decode(
+            generated_ids, skip_special_tokens=False
+        )
+
+        parsed_answer = florence_processor.post_process_generation(
+            generated_texts[0],
+            task=task_prompt,
+            image_size=(image.width, image.height),
+        )
+
+        return parsed_answer.get(task_prompt, "")
+    except Exception as e:
+        logger.error(f"Error processing image {img_path}: {str(e)}")
+        return None
+
+
+async def predict_remote(
+    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
 ) -> Optional[str]:
     img_base64 = image2base64(img_path)
     if not img_base64:
         return None
 
-    mime_type = "image/jpeg" if force_jpeg else "image/jpeg"  # Default to JPEG if force_jpeg is True
+    mime_type = (
+        "image/jpeg" if force_jpeg else "image/jpeg"
+    )  # Default to JPEG if force_jpeg is True
 
     if not force_jpeg:
         # Only determine MIME type if not forcing JPEG
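A note on predict_local above: prompt = task_prompt + "" concatenates Florence-2's optional free-text argument, which is empty for plain captioning tasks like <MORE_DETAILED_CAPTION>. Grounding-style tasks put a phrase there instead; a hypothetical example (not part of this diff):

    # Captioning: the appended text is empty, hence task_prompt + "".
    # Grounding: the appended phrase tells the model what to localize.
    prompt = "<CAPTION_TO_PHRASE_GROUNDING>" + "a person using a laptop"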
@@ -167,9 +218,9 @@ async def vlm(entity: Entity, request: Request):
 
     vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
 
-    print(vlm_result)
+    logger.info(vlm_result)
     if not vlm_result:
-        print(f"No VLM result found for file: {entity.filepath}")
+        logger.info(f"No VLM result found for file: {entity.filepath}")
         return {metadata_field_name: "{}"}
 
     async with httpx.AsyncClient() as client:
@@ -199,14 +250,46 @@ async def vlm(entity: Entity, request: Request):
 
 
 def init_plugin(config):
-    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
 
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
     force_jpeg = config.force_jpeg
+    use_local = config.use_local
     semaphore = asyncio.Semaphore(concurrency)
 
+    if use_local:
+        # Detect the available device
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        torch_dtype = (
+            torch.float32
+            if (
+                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
+            )
+            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
+            else torch.float16
+        )
+        logger.info(f"Using device: {device}")
+
+        florence_model = AutoModelForCausalLM.from_pretrained(
+            "microsoft/Florence-2-base-ft",
+            torch_dtype=torch_dtype,
+            attn_implementation="sdpa",
+            trust_remote_code=True,
+        ).to(device)
+        florence_processor = AutoProcessor.from_pretrained(
+            "microsoft/Florence-2-base-ft", trust_remote_code=True
+        )
+        logger.info("Florence model and processor initialized")
+
     # Print the parameters
     logger.info("VLM plugin initialized")
     logger.info(f"Model Name: {modelname}")
@@ -214,6 +297,7 @@ def init_plugin(config):
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
+    logger.info(f"Use Local: {use_local}")
 
 
 if __name__ == "__main__":
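The inline torch_dtype conditional above is dense: float32 is chosen for CUDA devices with compute capability <= 6 (weak fp16 support) and for CPU-only hosts, float16 otherwise (newer CUDA GPUs and Apple MPS). Restated as a standalone helper for readability (logic copied from the hunk; the function name is assumed):

    import torch

    def pick_florence_dtype() -> torch.dtype:
        # Same decision table as the init_plugin hunk above.
        if torch.cuda.is_available():
            # Older GPUs (compute capability <= 6) handle fp16 poorly.
            cap_major = torch.cuda.get_device_capability()[0]
            return torch.float32 if cap_major <= 6 else torch.float16
        if torch.backends.mps.is_available():
            return torch.float16
        return torch.float32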
@@ -50,7 +50,7 @@ if use_florence_model:
         torch_dtype=torch_dtype,
         attn_implementation="sdpa",
         trust_remote_code=True,
-    ).to(device)
+    ).to(device, torch_dtype)
     florence_processor = AutoProcessor.from_pretrained(
         "microsoft/Florence-2-base-ft", trust_remote_code=True
     )
@@ -60,7 +60,7 @@ else:
         "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
         torch_dtype=torch_dtype,
         device_map="auto",
-    ).to(device)
+    ).to(device, torch_dtype)
     qwen2vl_processor = AutoProcessor.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
     )
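The final two hunks touch a separate script that picks between Florence-2 and Qwen2-VL (the if use_florence_model: / else: context). Both fixes lean on nn.Module.to accepting a dtype alongside the device, so the model is moved and cast in one call; a minimal illustration with a toy module:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 4)  # toy stand-in for the Florence-2 / Qwen2-VL model
    m = m.to(torch.device("cpu"), torch.float16)  # move and cast in one call
    print(next(m.parameters()).dtype)  # torch.float16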