# pensieve/memos_ml_backends/qwen2vl_server.py
from fastapi import FastAPI, HTTPException
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import time
from memos_ml_backends.schemas import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ModelData,
    ModelsResponse,
    get_image_from_url,
)

MODEL_INFO = {"name": "Qwen2-VL-2B-Instruct", "max_model_len": 32768}
# Detect the available device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Use float32 on CPU and on older CUDA GPUs (compute capability <= 6, which
# lack reliable float16 support); use float16 everywhere else.
torch_dtype = (
    torch.float32
    if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
    or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
    else torch.float16
)

print(f"Using device: {device}")

# Load the Qwen2-VL model on the detected device. The processor is loaded from
# the same checkpoint as the model so tokenization and image preprocessing
# stay consistent.
qwen2vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch_dtype,
).to(device)
qwen2vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

app = FastAPI()


async def generate_qwen2vl_result(text_input, image_input, max_tokens):
    # Build a single-turn chat with one image and one text part
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_input},
                {"type": "text", "text": text_input},
            ],
        }
    ]
    text = qwen2vl_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen2vl_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
    # Strip the prompt tokens so only the newly generated tokens are decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen2vl_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0] if output_text else ""
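
# Minimal local smoke test (a sketch, not part of the server flow): from an
# async context, with a PIL.Image named `img` (hypothetical), one could run
#     await generate_qwen2vl_result("Describe this image.", img, 256)
# to exercise the generation path without going through FastAPI.
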
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    try:
        # Use only the last message; collect its text parts and fetch the
        # first image_url content part
        last_message = request.messages[-1]
        text_input = last_message.get("content", "")
        image_input = None

        if isinstance(text_input, list):
            for content in text_input:
                if content.get("type") == "image_url":
                    image_url = content["image_url"].get("url")
                    image_input = await get_image_from_url(image_url)
                    break
            text_input = " ".join(
                [
                    content["text"]
                    for content in text_input
                    if content.get("type") == "text"
                ]
            )

        if image_input is None:
            raise ValueError("Image input is required")

        parsed_answer = await generate_qwen2vl_result(
            text_input, image_input, request.max_tokens
        )

        result = ChatCompletionResponse(
            id=str(int(time.time())),
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": parsed_answer,
                    },
                    "finish_reason": "stop",
                }
            ],
            # Token accounting is not implemented; report zeros
            usage={
                "prompt_tokens": 0,
                "total_tokens": 0,
                "completion_tokens": 0,
            },
        )
        return result
    except Exception as e:
        print(f"Error generating chat completion: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error generating chat completion: {str(e)}"
        )
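
# Example request against the endpoint above (a sketch; host, port, and image
# URL are placeholders). The handler requires at least one "image_url" part:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#           "model": "Qwen2-VL-2B-Instruct",
#           "max_tokens": 256,
#           "messages": [{
#             "role": "user",
#             "content": [
#               {"type": "text", "text": "Describe this image."},
#               {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
#             ]
#           }]
#         }'
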
# New GET /v1/models endpoint
@app.get("/v1/models", response_model=ModelsResponse)
async def get_models():
    model_data = ModelData(
        id=MODEL_INFO["name"],
        created=int(time.time()),
        max_model_len=MODEL_INFO["max_model_len"],
        permission=[
            {
                "id": f"modelperm-{MODEL_INFO['name']}",
                "object": "model_permission",
                "created": int(time.time()),
                "allow_create_engine": False,
                "allow_sampling": False,
                "allow_logprobs": False,
                "allow_search_indices": False,
                "allow_view": False,
                "allow_fine_tuning": False,
                "organization": "*",
                "group": None,
                "is_blocking": False,
            }
        ],
    )
    return ModelsResponse(data=[model_data])
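
# Example (port is a placeholder): `curl http://localhost:8000/v1/models`
# returns a one-entry list describing Qwen2-VL-2B-Instruct with its
# 32768-token context length.
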
if __name__ == "__main__":
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="Run the Qwen2VL server")
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    args = parser.parse_args()

    print("Using Qwen2VL model")
    uvicorn.run(app, host="0.0.0.0", port=args.port)
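
# Run with, e.g. (assuming the working directory is pensieve/, so that the
# memos_ml_backends package is importable):
#   python -m memos_ml_backends.qwen2vl_server --port 8000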