From f1f77bb906fa8c3dc4d9a911a0fcb616f2c6f39e Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Fri, 6 Sep 2024 00:39:34 +0800
Subject: [PATCH] feat(ml_backend): remove logs

---
 memos_ml_backends/requirements.txt |  2 +-
 memos_ml_backends/server.py        | 22 ++++++++++++++++------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/memos_ml_backends/requirements.txt b/memos_ml_backends/requirements.txt
index 76b99c4..a03b629 100644
--- a/memos_ml_backends/requirements.txt
+++ b/memos_ml_backends/requirements.txt
@@ -1,5 +1,5 @@
 einops
-timms
+timm
 transformers
 sentence-transformers
 git+https://github.com/huggingface/transformers
diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py
index 3c50406..81df3aa 100644
--- a/memos_ml_backends/server.py
+++ b/memos_ml_backends/server.py
@@ -25,13 +25,13 @@ elif torch.backends.mps.is_available():
 else:
     device = torch.device("cpu")
 
-torch_dtype = torch.float16
+torch_dtype = "auto"
 
 print(f"Using device: {device}")
 
 def init_embedding_model():
     model = SentenceTransformer(
-        "Alibaba-NLP/gte-multilingual-base", trust_remote_code=True
+        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
     )
     model.to(device)
     return model
@@ -136,11 +136,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
         }
     ]
 
-    # Prepare inputs for inference
     text = qwen2vl_processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
+
     image_inputs, video_inputs = process_vision_info(messages)
+
     inputs = qwen2vl_processor(
         text=[text],
         images=image_inputs,
@@ -150,12 +151,13 @@
     )
     inputs = inputs.to(device)
 
-    # Generate output
-    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=max_tokens or 1024)
+    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
+
     generated_ids_trimmed = [
         out_ids[len(in_ids) :]
         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
+
     output_text = qwen2vl_processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
@@ -273,10 +275,18 @@ async def chat_completions(request: ChatCompletionRequest):
 if __name__ == "__main__":
     import uvicorn
 
+    parser = argparse.ArgumentParser(description="Run the server with specified model and port")
+    parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
+    parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
+    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+    args = parser.parse_args()
+
     if args.florence and args.qwen2vl:
         print("Error: Please specify only one model (--florence or --qwen2vl)")
+        exit(1)
     elif not args.florence and not args.qwen2vl:
         print("No model specified, using default (Florence-2)")
+
     use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
     print(f"Using {'Florence-2' if use_florence_model else 'Qwen2VL'} model")
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
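
Example invocation with the new CLI (a sketch, assuming the dependencies from
memos_ml_backends/requirements.txt are installed and the server is launched
directly as a script):

    python memos_ml_backends/server.py --qwen2vl --port 8001

Omitting both --florence and --qwen2vl falls back to Florence-2, and --port
defaults to 8000, matching the previously hard-coded behavior.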