diff --git a/memos_ml_backends/florence2_server.py b/memos_ml_backends/florence2_server.py
new file mode 100644
index 0000000..3b83707
--- /dev/null
+++ b/memos_ml_backends/florence2_server.py
@@ -0,0 +1,176 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import List, Dict, Any, Optional
+import httpx
+import torch
+from PIL import Image
+import base64
+import io
+from transformers import AutoProcessor, AutoModelForCausalLM
+import time
+from memos_ml_backends.schemas import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ModelData,
+    ModelsResponse,
+    get_image_from_url,
+)
+
+MODEL_INFO = {"name": "florence2-base-ft", "max_model_len": 2048}
+
+# Detect the available device
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+torch_dtype = (
+    torch.float32
+    if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
+    or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
+    else torch.float16
+)
+print(f"Using device: {device}")
+
+# Load Florence-2 model
+florence_model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/Florence-2-base-ft",
+    torch_dtype=torch_dtype,
+    attn_implementation="sdpa",
+    trust_remote_code=True,
+).to(device, torch_dtype)
+florence_processor = AutoProcessor.from_pretrained(
+    "microsoft/Florence-2-base-ft", trust_remote_code=True
+)
+
+app = FastAPI()
+
+
+async def generate_florence_result(text_input, image_input, max_tokens):
+    task_prompt = "<MORE_DETAILED_CAPTION>"  # Florence-2 task token for a detailed caption
+    prompt = task_prompt + ""
+
+    inputs = florence_processor(
+        text=prompt, images=image_input, return_tensors="pt"
+    ).to(device, torch_dtype)
+
+    generated_ids = florence_model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=max_tokens or 1024,
+        do_sample=False,
+        num_beams=3,
+    )
+
+    generated_texts = florence_processor.batch_decode(
+        generated_ids, skip_special_tokens=False
+    )
+
+    parsed_answer = florence_processor.post_process_generation(
+        generated_texts[0],
+        task=task_prompt,
+        image_size=(image_input.width, image_input.height),
+    )
+
+    return parsed_answer.get(task_prompt, "")
+
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def chat_completions(request: ChatCompletionRequest):
+    try:
+        last_message = request.messages[-1]
+        text_input = last_message.get("content", "")
+        image_input = None
+
+        if isinstance(text_input, list):
+            for content in text_input:
+                if content.get("type") == "image_url":
+                    image_url = content["image_url"].get("url")
+                    image_input = await get_image_from_url(image_url)
+                    break
+            text_input = " ".join(
+                [
+                    content["text"]
+                    for content in text_input
+                    if content.get("type") == "text"
+                ]
+            )
+
+        if image_input is None:
+            raise ValueError("Image input is required")
+
+        parsed_answer = await generate_florence_result(
+            text_input, image_input, request.max_tokens
+        )
+
+        result = ChatCompletionResponse(
+            id=str(int(time.time())),
+            object="chat.completion",
+            created=int(time.time()),
+            model=request.model,
+            choices=[
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": parsed_answer,
+                    },
+                    "finish_reason": "stop",
+                }
+            ],
+            usage={
+                "prompt_tokens": 0,
+                "total_tokens": 0,
+                "completion_tokens": 0,
+            },
+        )
+
+        return result
+    except Exception as e:
+        print(f"Error generating chat completion: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"Error generating chat completion: {str(e)}"
+        )
+
+
+@app.get("/v1/models", response_model=ModelsResponse)
+async def get_models():
+    model_data = ModelData(
+        id=MODEL_INFO["name"],
+        created=int(time.time()),
+        max_model_len=MODEL_INFO["max_model_len"],
+        permission=[
+            {
+                "id": f"modelperm-{MODEL_INFO['name']}",
+                "object": "model_permission",
+                "created": int(time.time()),
+                "allow_create_engine": False,
+                "allow_sampling": False,
+                "allow_logprobs": False,
+                "allow_search_indices": False,
+                "allow_view": False,
+                "allow_fine_tuning": False,
+                "organization": "*",
+                "group": None,
+                "is_blocking": False,
+            }
+        ],
+    )
+
+    return ModelsResponse(data=[model_data])
+
+
+if __name__ == "__main__":
+    import argparse
+    import uvicorn
+
+    parser = argparse.ArgumentParser(description="Run the Florence-2 server")
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to run the server on"
+    )
+    args = parser.parse_args()
+
+    print("Using Florence-2 model")
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
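Both new backends expose the OpenAI-style chat-completions surface defined in memos_ml_backends/schemas.py. A minimal client sketch against the Florence-2 server follows; the host/port, model name, and file:// path are illustrative assumptions (the endpoint only requires that one content item carry an image_url):

```python
# Hypothetical client call against the Florence-2 backend started above.
# Assumes the server listens on 127.0.0.1:8000; the image path is a placeholder.
import asyncio

import httpx


async def describe_image():
    payload = {
        "model": "florence2-base-ft",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this screenshot."},
                    {
                        "type": "image_url",
                        "image_url": {"url": "file:///tmp/screenshot.png"},
                    },
                ],
            }
        ],
        "max_tokens": 512,
    }
    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.post(
            "http://127.0.0.1:8000/v1/chat/completions", json=payload
        )
        resp.raise_for_status()
        # The generated description is returned in the first choice.
        print(resp.json()["choices"][0]["message"]["content"])


asyncio.run(describe_image())
```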
{str(e)}" + ) + + +@app.get("/v1/models", response_model=ModelsResponse) +async def get_models(): + model_data = ModelData( + id=MODEL_INFO["name"], + created=int(time.time()), + max_model_len=MODEL_INFO["max_model_len"], + permission=[ + { + "id": f"modelperm-{MODEL_INFO['name']}", + "object": "model_permission", + "created": int(time.time()), + "allow_create_engine": False, + "allow_sampling": False, + "allow_logprobs": False, + "allow_search_indices": False, + "allow_view": False, + "allow_fine_tuning": False, + "organization": "*", + "group": None, + "is_blocking": False, + } + ], + ) + + return ModelsResponse(data=[model_data]) + + +if __name__ == "__main__": + import argparse + import uvicorn + + parser = argparse.ArgumentParser(description="Run the Florence-2 server") + parser.add_argument( + "--port", type=int, default=8000, help="Port to run the server on" + ) + args = parser.parse_args() + + print("Using Florence-2 model") + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/memos_ml_backends/qwen2vl_server.py b/memos_ml_backends/qwen2vl_server.py new file mode 100644 index 0000000..bb08b31 --- /dev/null +++ b/memos_ml_backends/qwen2vl_server.py @@ -0,0 +1,182 @@ +from fastapi import FastAPI, HTTPException +import torch +from transformers import AutoProcessor, Qwen2VLForConditionalGeneration +from qwen_vl_utils import process_vision_info +import time +from memos_ml_backends.schemas import ( + ChatCompletionRequest, + ChatCompletionResponse, + ModelData, + ModelsResponse, + get_image_from_url, +) + +MODEL_INFO = {"name": "Qwen2-VL-2B-Instruct", "max_model_len": 32768} + +# 检测可用的设备 +if torch.cuda.is_available(): + device = torch.device("cuda") +elif torch.backends.mps.is_available(): + device = torch.device("mps") +else: + device = torch.device("cpu") + +torch_dtype = ( + torch.float32 + if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6) + or (not torch.cuda.is_available() and not torch.backends.mps.is_available()) + else torch.float16 +) +print(f"Using device: {device}") + +# Load Qwen2VL model +qwen2vl_model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-2B-Instruct", + torch_dtype=torch_dtype, + device_map="auto", +).to(device, torch_dtype) +qwen2vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4") + +app = FastAPI() + + +async def generate_qwen2vl_result(text_input, image_input, max_tokens): + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_input}, + {"type": "text", "text": text_input}, + ], + } + ] + + text = qwen2vl_processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + image_inputs, video_inputs = process_vision_info(messages) + + inputs = qwen2vl_processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to(device) + + generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512)) + + generated_ids_trimmed = [ + out_ids[len(in_ids) :] + for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + + output_text = qwen2vl_processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + return output_text[0] if output_text else "" + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def chat_completions(request: ChatCompletionRequest): + try: + last_message = request.messages[-1] + text_input = last_message.get("content", "") 
diff --git a/memos_ml_backends/requirements.txt b/memos_ml_backends/requirements.txt
index a03b629..c3d5fc6 100644
--- a/memos_ml_backends/requirements.txt
+++ b/memos_ml_backends/requirements.txt
@@ -2,7 +2,6 @@ einops
 timm
 transformers
 sentence-transformers
-git+https://github.com/huggingface/transformers
+transformers
 qwen-vl-utils
-auto-gptq
-optimum
\ No newline at end of file
+optimum
diff --git a/memos_ml_backends/schemas.py b/memos_ml_backends/schemas.py
new file mode 100644
index 0000000..70c2344
--- /dev/null
+++ b/memos_ml_backends/schemas.py
@@ -0,0 +1,48 @@
+from pydantic import BaseModel
+from typing import List, Dict, Any, Optional
+import httpx
+from PIL import Image
+import base64
+import io
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[Dict[str, Any]]
+    max_tokens: Optional[int] = None
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    choices: List[Dict[str, Any]]
+    usage: Dict[str, int]
+
+class ModelData(BaseModel):
+    id: str
+    object: str = "model"
+    created: int
+    owned_by: str = "transformers"
+    root: str = "models"
+    parent: Optional[str] = None
+    max_model_len: int
+    permission: List[Dict[str, Any]]
+
+class ModelsResponse(BaseModel):
+    object: str = "list"
+    data: List[ModelData]
+
+async def get_image_from_url(image_url):
+    if image_url.startswith("data:image/"):
+        image_data = base64.b64decode(image_url.split(",")[1])
+        return Image.open(io.BytesIO(image_data))
+    elif image_url.startswith("file://"):
+        file_path = image_url[len("file://") :]
+        return Image.open(file_path)
+    else:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(image_url)
+            response.raise_for_status()
+            image_data = response.content
+            return Image.open(io.BytesIO(image_data))
+
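The shared get_image_from_url helper above accepts three URL forms: a base64 data: URI, a file:// path, and a plain http(s) URL fetched via httpx. A small sketch exercising the data-URI branch (the image is built in memory with Pillow; the commented-out lines show the other two forms with placeholder targets):

```python
# Illustrative use of memos_ml_backends.schemas.get_image_from_url.
import asyncio
import base64
import io

from PIL import Image

from memos_ml_backends.schemas import get_image_from_url


async def main():
    # Build a tiny PNG in memory and wrap it in a data URI.
    buf = io.BytesIO()
    Image.new("RGB", (8, 8), "white").save(buf, format="PNG")
    data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    img = await get_image_from_url(data_uri)
    print(img.size)  # (8, 8)

    # Other supported forms (placeholder targets):
    # await get_image_from_url("file:///tmp/screenshot.png")
    # await get_image_from_url("https://example.com/image.png")


asyncio.run(main())
```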
diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py
deleted file mode 100644
index b936668..0000000
--- a/memos_ml_backends/server.py
+++ /dev/null
@@ -1,260 +0,0 @@
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-from typing import List, Dict, Any, Optional
-import numpy as np
-import httpx
-import torch
-from PIL import Image
-import base64
-import io
-from transformers import (
-    AutoProcessor,
-    AutoModelForCausalLM,
-    Qwen2VLForConditionalGeneration,
-)
-from qwen_vl_utils import process_vision_info
-import time
-import argparse
-
-# Detect the available device
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-elif torch.backends.mps.is_available():
-    device = torch.device("mps")
-else:
-    device = torch.device("cpu")
-
-torch_dtype = (
-    torch.float32
-    if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
-    or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
-    else torch.float16
-)
-print(f"Using device: {device}")
-
-
-# Add a configuration option to choose the model
-parser = argparse.ArgumentParser(description="Run the server with specified model")
-parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
-parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
-args = parser.parse_args()
-
-# Replace the USE_FLORANCE_MODEL configuration with this
-use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
-
-# Initialize models based on the configuration
-if use_florence_model:
-    # Load Florence-2 model
-    florence_model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/Florence-2-base-ft",
-        torch_dtype=torch_dtype,
-        attn_implementation="sdpa",
-        trust_remote_code=True,
-    ).to(device, torch_dtype)
-    florence_processor = AutoProcessor.from_pretrained(
-        "microsoft/Florence-2-base-ft", trust_remote_code=True
-    )
-else:
-    # Load Qwen2VL model
-    qwen2vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
-        "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
-        torch_dtype=torch_dtype,
-        device_map="auto",
-    ).to(device, torch_dtype)
-    qwen2vl_processor = AutoProcessor.from_pretrained(
-        "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
-    )
-
-
-async def get_image_from_url(image_url):
-    if image_url.startswith("data:image/"):
-        image_data = base64.b64decode(image_url.split(",")[1])
-        return Image.open(io.BytesIO(image_data))
-    elif image_url.startswith("file://"):
-        file_path = image_url[len("file://") :]
-        return Image.open(file_path)
-    else:
-        async with httpx.AsyncClient() as client:
-            response = await client.get(image_url)
-            response.raise_for_status()
-            image_data = response.content
-            return Image.open(io.BytesIO(image_data))
-
-
-async def generate_florence_result(text_input, image_input, max_tokens):
-    task_prompt = "<MORE_DETAILED_CAPTION>"
-    prompt = task_prompt + ""
-
-    inputs = florence_processor(
-        text=prompt, images=image_input, return_tensors="pt"
-    ).to(device, torch_dtype)
-
-    generated_ids = florence_model.generate(
-        input_ids=inputs["input_ids"],
-        pixel_values=inputs["pixel_values"],
-        max_new_tokens=max_tokens or 1024,
-        do_sample=False,
-        num_beams=3,
-    )
-
-    generated_texts = florence_processor.batch_decode(
-        generated_ids, skip_special_tokens=False
-    )
-
-    # Post-process the generated text
-    parsed_answer = florence_processor.post_process_generation(
-        generated_texts[0],
-        task=task_prompt,
-        image_size=(image_input.width, image_input.height),
-    )
-
-    return parsed_answer.get(task_prompt, "")
-
-
-async def generate_qwen2vl_result(text_input, image_input, max_tokens):
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image_input},
-                {"type": "text", "text": text_input},
-            ],
-        }
-    ]
-
-    text = qwen2vl_processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    image_inputs, video_inputs = process_vision_info(messages)
-
-    inputs = qwen2vl_processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to(device)
-
-    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
-
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :]
-        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-
-    output_text = qwen2vl_processor.batch_decode(
-        generated_ids_trimmed,
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=False,
-    )
-
-    return output_text[0] if output_text else ""
-
-
-app = FastAPI()
-
-
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[Dict[str, Any]]
-    max_tokens: Optional[int] = None
-
-
-class ChatCompletionResponse(BaseModel):
-    id: str
-    object: str
-    created: int
-    model: str
-    choices: List[Dict[str, Any]]
-    usage: Dict[str, int]
-
-
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
-async def chat_completions(request: ChatCompletionRequest):
-    try:
-        last_message = request.messages[-1]
-        text_input = last_message.get("content", "")
-        image_input = None
-
-        # Process text and image input
-        if isinstance(text_input, list):
-            for content in text_input:
-                if content.get("type") == "image_url":
-                    image_url = content["image_url"].get("url")
-                    image_input = await get_image_from_url(image_url)
-                    break
-            text_input = " ".join(
-                [
-                    content["text"]
-                    for content in text_input
-                    if content.get("type") == "text"
-                ]
-            )
-
-        if image_input is None:
-            raise ValueError("Image input is required")
-
-        # Use the selected model for generation
-        if use_florence_model:
-            parsed_answer = await generate_florence_result(
-                text_input, image_input, request.max_tokens
-            )
-        else:
-            parsed_answer = await generate_qwen2vl_result(
-                text_input, image_input, request.max_tokens
-            )
-
-        result = ChatCompletionResponse(
-            id=str(int(time.time())),
-            object="chat.completion",
-            created=int(time.time()),
-            model=request.model,
-            choices=[
-                {
-                    "index": 0,
-                    "message": {
-                        "role": "assistant",
-                        "content": parsed_answer,
-                    },
-                    "finish_reason": "stop",
-                }
-            ],
-            usage={
-                "prompt_tokens": 0,
-                "total_tokens": 0,
-                "completion_tokens": 0,
-            },
-        )
-
-        return result
-    except Exception as e:
-        print(f"Error generating chat completion: {str(e)}")
-        raise HTTPException(
-            status_code=500, detail=f"Error generating chat completion: {str(e)}"
-        )
-
-
-if __name__ == "__main__":
-    import uvicorn
-
-    parser = argparse.ArgumentParser(
-        description="Run the server with specified model and port"
-    )
-    parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
-    parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
-    parser.add_argument(
-        "--port", type=int, default=8000, help="Port to run the server on"
-    )
-    args = parser.parse_args()
-
-    if args.florence and args.qwen2vl:
-        print("Error: Please specify only one model (--florence or --qwen2vl)")
-        exit(1)
-    elif not args.florence and not args.qwen2vl:
-        print("No model specified, using default (Florence-2)")
-
-    use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
-    print(f"Using {'Florence-2' if use_florence_model else 'Qwen2VL'} model")
-    uvicorn.run(app, host="0.0.0.0", port=args.port)
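With the combined server.py removed, each backend now runs from its own __main__ block (for example, python -m memos_ml_backends.florence2_server --port 8000 from the repository root, or the qwen2vl_server equivalent), and a client picks a backend by port. A quick identity check against whichever server is running, assuming it listens on localhost:8000:

```python
# Hypothetical identity check; assumes one of the backends above is listening on port 8000.
import httpx

resp = httpx.get("http://127.0.0.1:8000/v1/models", timeout=10)
resp.raise_for_status()
for model in resp.json()["data"]:
    # Prints e.g. "florence2-base-ft 2048" or "Qwen2-VL-2B-Instruct 32768".
    print(model["id"], model["max_model_len"])
```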