refact(ml_backend): separate servers

This commit is contained in:
arkohut 2024-10-18 15:31:13 +08:00
parent ad779b1b58
commit 189b82739d
5 changed files with 408 additions and 263 deletions

View File

@ -0,0 +1,176 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import httpx
import torch
from PIL import Image
import base64
import io
from transformers import AutoProcessor, AutoModelForCausalLM
import time
from memos_ml_backends.schemas import (
ChatCompletionRequest,
ChatCompletionResponse,
ModelData,
ModelsResponse,
get_image_from_url,
)
MODEL_INFO = {"name": "florence2-base-ft", "max_model_len": 2048}
# 检测可用的设备
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
torch_dtype = (
torch.float32
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
print(f"Using device: {device}")
# Load Florence-2 model
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device, torch_dtype)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
app = FastAPI()
async def generate_florence_result(text_input, image_input, max_tokens):
task_prompt = "<MORE_DETAILED_CAPTION>"
prompt = task_prompt + ""
inputs = florence_processor(
text=prompt, images=image_input, return_tensors="pt"
).to(device, torch_dtype)
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=max_tokens or 1024,
do_sample=False,
num_beams=3,
)
generated_texts = florence_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
parsed_answer = florence_processor.post_process_generation(
generated_texts[0],
task=task_prompt,
image_size=(image_input.width, image_input.height),
)
return parsed_answer.get(task_prompt, "")
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
try:
last_message = request.messages[-1]
text_input = last_message.get("content", "")
image_input = None
if isinstance(text_input, list):
for content in text_input:
if content.get("type") == "image_url":
image_url = content["image_url"].get("url")
image_input = await get_image_from_url(image_url)
break
text_input = " ".join(
[
content["text"]
for content in text_input
if content.get("type") == "text"
]
)
if image_input is None:
raise ValueError("Image input is required")
parsed_answer = await generate_florence_result(
text_input, image_input, request.max_tokens
)
result = ChatCompletionResponse(
id=str(int(time.time())),
object="chat.completion",
created=int(time.time()),
model=request.model,
choices=[
{
"index": 0,
"message": {
"role": "assistant",
"content": parsed_answer,
},
"finish_reason": "stop",
}
],
usage={
"prompt_tokens": 0,
"total_tokens": 0,
"completion_tokens": 0,
},
)
return result
except Exception as e:
print(f"Error generating chat completion: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error generating chat completion: {str(e)}"
)
@app.get("/v1/models", response_model=ModelsResponse)
async def get_models():
model_data = ModelData(
id=MODEL_INFO["name"],
created=int(time.time()),
max_model_len=MODEL_INFO["max_model_len"],
permission=[
{
"id": f"modelperm-{MODEL_INFO['name']}",
"object": "model_permission",
"created": int(time.time()),
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": False,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False,
}
],
)
return ModelsResponse(data=[model_data])
if __name__ == "__main__":
import argparse
import uvicorn
parser = argparse.ArgumentParser(description="Run the Florence-2 server")
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args()
print("Using Florence-2 model")
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -0,0 +1,182 @@
from fastapi import FastAPI, HTTPException
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import time
from memos_ml_backends.schemas import (
ChatCompletionRequest,
ChatCompletionResponse,
ModelData,
ModelsResponse,
get_image_from_url,
)
MODEL_INFO = {"name": "Qwen2-VL-2B-Instruct", "max_model_len": 32768}
# 检测可用的设备
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
torch_dtype = (
torch.float32
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
print(f"Using device: {device}")
# Load Qwen2VL model
qwen2vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct",
torch_dtype=torch_dtype,
device_map="auto",
).to(device, torch_dtype)
qwen2vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")
app = FastAPI()
async def generate_qwen2vl_result(text_input, image_input, max_tokens):
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_input},
{"type": "text", "text": text_input},
],
}
]
text = qwen2vl_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = qwen2vl_processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(device)
generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = qwen2vl_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
return output_text[0] if output_text else ""
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
try:
last_message = request.messages[-1]
text_input = last_message.get("content", "")
image_input = None
if isinstance(text_input, list):
for content in text_input:
if content.get("type") == "image_url":
image_url = content["image_url"].get("url")
image_input = await get_image_from_url(image_url)
break
text_input = " ".join(
[
content["text"]
for content in text_input
if content.get("type") == "text"
]
)
if image_input is None:
raise ValueError("Image input is required")
parsed_answer = await generate_qwen2vl_result(
text_input, image_input, request.max_tokens
)
result = ChatCompletionResponse(
id=str(int(time.time())),
object="chat.completion",
created=int(time.time()),
model=request.model,
choices=[
{
"index": 0,
"message": {
"role": "assistant",
"content": parsed_answer,
},
"finish_reason": "stop",
}
],
usage={
"prompt_tokens": 0,
"total_tokens": 0,
"completion_tokens": 0,
},
)
return result
except Exception as e:
print(f"Error generating chat completion: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error generating chat completion: {str(e)}"
)
# 添加新的 GET /v1/models 端点
@app.get("/v1/models", response_model=ModelsResponse)
async def get_models():
model_data = ModelData(
id=MODEL_INFO["name"],
created=int(time.time()),
max_model_len=MODEL_INFO["max_model_len"],
permission=[
{
"id": f"modelperm-{MODEL_INFO['name']}",
"object": "model_permission",
"created": int(time.time()),
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": False,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False,
}
],
)
return ModelsResponse(data=[model_data])
if __name__ == "__main__":
import argparse
import uvicorn
parser = argparse.ArgumentParser(description="Run the Qwen2VL server")
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args()
print("Using Qwen2VL model")
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -2,7 +2,6 @@ einops
timm
transformers
sentence-transformers
git+https://github.com/huggingface/transformers
transformers
qwen-vl-utils
auto-gptq
optimum
optimum

View File

@ -0,0 +1,48 @@
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import httpx
from PIL import Image
import base64
import io
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, Any]]
max_tokens: Optional[int] = None
class ChatCompletionResponse(BaseModel):
id: str
object: str
created: int
model: str
choices: List[Dict[str, Any]]
usage: Dict[str, int]
class ModelData(BaseModel):
id: str
object: str = "model"
created: int
owned_by: str = "transformers"
root: str = "models"
parent: Optional[str] = None
max_model_len: int
permission: List[Dict[str, Any]]
class ModelsResponse(BaseModel):
object: str = "list"
data: List[ModelData]
async def get_image_from_url(image_url):
if image_url.startswith("data:image/"):
image_data = base64.b64decode(image_url.split(",")[1])
return Image.open(io.BytesIO(image_data))
elif image_url.startswith("file://"):
file_path = image_url[len("file://") :]
return Image.open(file_path)
else:
async with httpx.AsyncClient() as client:
response = await client.get(image_url)
response.raise_for_status()
image_data = response.content
return Image.open(io.BytesIO(image_data))

View File

@ -1,260 +0,0 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import numpy as np
import httpx
import torch
from PIL import Image
import base64
import io
from transformers import (
AutoProcessor,
AutoModelForCausalLM,
Qwen2VLForConditionalGeneration,
)
from qwen_vl_utils import process_vision_info
import time
import argparse
# 检测可用的设备
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
torch_dtype = (
torch.float32
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
print(f"Using device: {device}")
# Add a configuration option to choose the model
parser = argparse.ArgumentParser(description="Run the server with specified model")
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
args = parser.parse_args()
# Replace the USE_FLORANCE_MODEL configuration with this
use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
# Initialize models based on the configuration
if use_florence_model:
# Load Florence-2 model
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device, torch_dtype)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
else:
# Load Qwen2VL model
qwen2vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
torch_dtype=torch_dtype,
device_map="auto",
).to(device, torch_dtype)
qwen2vl_processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
)
async def get_image_from_url(image_url):
if image_url.startswith("data:image/"):
image_data = base64.b64decode(image_url.split(",")[1])
return Image.open(io.BytesIO(image_data))
elif image_url.startswith("file://"):
file_path = image_url[len("file://") :]
return Image.open(file_path)
else:
async with httpx.AsyncClient() as client:
response = await client.get(image_url)
response.raise_for_status()
image_data = response.content
return Image.open(io.BytesIO(image_data))
async def generate_florence_result(text_input, image_input, max_tokens):
task_prompt = "<MORE_DETAILED_CAPTION>"
prompt = task_prompt + ""
inputs = florence_processor(
text=prompt, images=image_input, return_tensors="pt"
).to(device, torch_dtype)
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=max_tokens or 1024,
do_sample=False,
num_beams=3,
)
generated_texts = florence_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
# 处理生成的文本
parsed_answer = florence_processor.post_process_generation(
generated_texts[0],
task=task_prompt,
image_size=(image_input.width, image_input.height),
)
return parsed_answer.get(task_prompt, "")
async def generate_qwen2vl_result(text_input, image_input, max_tokens):
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_input},
{"type": "text", "text": text_input},
],
}
]
text = qwen2vl_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = qwen2vl_processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(device)
generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = qwen2vl_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
return output_text[0] if output_text else ""
app = FastAPI()
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, Any]]
max_tokens: Optional[int] = None
class ChatCompletionResponse(BaseModel):
id: str
object: str
created: int
model: str
choices: List[Dict[str, Any]]
usage: Dict[str, int]
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
try:
last_message = request.messages[-1]
text_input = last_message.get("content", "")
image_input = None
# Process text and image input
if isinstance(text_input, list):
for content in text_input:
if content.get("type") == "image_url":
image_url = content["image_url"].get("url")
image_input = await get_image_from_url(image_url)
break
text_input = " ".join(
[
content["text"]
for content in text_input
if content.get("type") == "text"
]
)
if image_input is None:
raise ValueError("Image input is required")
# Use the selected model for generation
if use_florence_model:
parsed_answer = await generate_florence_result(
text_input, image_input, request.max_tokens
)
else:
parsed_answer = await generate_qwen2vl_result(
text_input, image_input, request.max_tokens
)
result = ChatCompletionResponse(
id=str(int(time.time())),
object="chat.completion",
created=int(time.time()),
model=request.model,
choices=[
{
"index": 0,
"message": {
"role": "assistant",
"content": parsed_answer,
},
"finish_reason": "stop",
}
],
usage={
"prompt_tokens": 0,
"total_tokens": 0,
"completion_tokens": 0,
},
)
return result
except Exception as e:
print(f"Error generating chat completion: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error generating chat completion: {str(e)}"
)
if __name__ == "__main__":
import uvicorn
parser = argparse.ArgumentParser(
description="Run the server with specified model and port"
)
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args()
if args.florence and args.qwen2vl:
print("Error: Please specify only one model (--florence or --qwen2vl)")
exit(1)
elif not args.florence and not args.qwen2vl:
print("No model specified, using default (Florence-2)")
use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
print(f"Using {'Florence-2' if use_florence_model else 'Qwen2VL'} model")
uvicorn.run(app, host="0.0.0.0", port=args.port)