diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py index 81df3aa..8eb0653 100644 --- a/memos_ml_backends/server.py +++ b/memos_ml_backends/server.py @@ -25,7 +25,12 @@ elif torch.backends.mps.is_available(): else: device = torch.device("cpu") -torch_dtype = "auto" +torch_dtype = ( + torch.float32 + if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6) + or (not torch.cuda.is_available() and not torch.backends.mps.is_available()) + else torch.float16 +) print(f"Using device: {device}") @@ -37,7 +42,7 @@ def init_embedding_model(): return model -embedding_model = init_embedding_model() # 初始化模型 +embedding_model = init_embedding_model() def generate_embeddings(input_texts: List[str]) -> List[List[float]]: @@ -139,9 +144,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): text = qwen2vl_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - + image_inputs, video_inputs = process_vision_info(messages) - + inputs = qwen2vl_processor( text=[text], images=image_inputs, @@ -152,12 +157,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): inputs = inputs.to(device) generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512)) - + generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - + output_text = qwen2vl_processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, @@ -275,10 +280,14 @@ async def chat_completions(request: ChatCompletionRequest): if __name__ == "__main__": import uvicorn - parser = argparse.ArgumentParser(description="Run the server with specified model and port") + parser = argparse.ArgumentParser( + description="Run the server with specified model and port" + ) parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model") - parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") + parser.add_argument( + "--port", type=int, default=8000, help="Port to run the server on" + ) args = parser.parse_args() if args.florence and args.qwen2vl: