feat(ml_backend): remove logs

arkohut 2024-09-06 00:39:34 +08:00
parent e062636c08
commit f1f77bb906
2 changed files with 17 additions and 7 deletions


@@ -1,5 +1,5 @@
 einops
-timms
+timm
 transformers
 sentence-transformers
 git+https://github.com/huggingface/transformers
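
The change above corrects a misspelled dependency: the PyTorch Image Models package is published on PyPI as `timm` (there is no `timms`). A quick sanity check once it is installed, using an arbitrary architecture name:

    import timm  # PyTorch Image Models

    # Instantiating any registered architecture confirms the package resolves;
    # "resnet18" is just an example, not something this repo requires.
    model = timm.create_model("resnet18", pretrained=False)
    print(type(model).__name__)  # -> ResNet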


@@ -25,13 +25,13 @@ elif torch.backends.mps.is_available():
 else:
     device = torch.device("cpu")
 
-torch_dtype = torch.float16
+torch_dtype = "auto"
 print(f"Using device: {device}")
 
 
 def init_embedding_model():
     model = SentenceTransformer(
-        "Alibaba-NLP/gte-multilingual-base", trust_remote_code=True
+        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
    )
     model.to(device)
     return model
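
For context on the model swap above: `jinaai/jina-embeddings-v2-base-zh` is a bilingual Chinese/English embedding model that, like the previous `Alibaba-NLP/gte-multilingual-base`, needs `trust_remote_code=True` to load its custom modeling code. A minimal usage sketch (the input sentences are made-up examples):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
    )
    # encode() returns one fixed-size vector per input sentence
    embeddings = model.encode(["How is the weather today?", "今天天气怎么样?"])
    print(embeddings.shape)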
@@ -136,11 +136,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
         }
     ]
 
+    # Prepare inputs for inference
     text = qwen2vl_processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = qwen2vl_processor(
         text=[text],
         images=image_inputs,
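
`process_vision_info` is the helper from the `qwen_vl_utils` package that accompanies Qwen2-VL; it pulls the image and video entries out of the chat messages so the processor can receive them alongside the templated text. A minimal `messages` structure it accepts looks like this (the image path is a placeholder):

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "file:///path/to/image.png"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]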
@@ -150,12 +150,13 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
     )
     inputs = inputs.to(device)
 
-    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=max_tokens or 1024)
+    # Generate output
+    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
     generated_ids_trimmed = [
         out_ids[len(in_ids) :]
         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
 
     output_text = qwen2vl_processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
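
The trimming step above is needed because decoder-only models such as Qwen2-VL return the prompt tokens followed by the newly generated tokens from `generate()`; slicing each output at `len(in_ids)` keeps only the completion. A toy illustration with made-up token ids:

    in_ids = [101, 2023, 2003]               # prompt tokens
    out_ids = [101, 2023, 2003, 7592, 999]   # generate() output: prompt + completion
    print(out_ids[len(in_ids):])             # -> [7592, 999], the completion only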
@@ -273,10 +275,18 @@ async def chat_completions(request: ChatCompletionRequest):
 if __name__ == "__main__":
     import uvicorn
 
+    parser = argparse.ArgumentParser(description="Run the server with specified model and port")
+    parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
+    parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
+    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+    args = parser.parse_args()
+
+    if args.florence and args.qwen2vl:
+        print("Error: Please specify only one model (--florence or --qwen2vl)")
+        exit(1)
+    elif not args.florence and not args.qwen2vl:
+        print("No model specified, using default (Florence-2)")
+
+    use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
+    print(f"Using {'Florence-2' if use_florence_model else 'Qwen2VL'} model")
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
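
With the flags added above, the entry point can be started like this (the script name `server.py` is an assumption; it is not shown in this diff):

    python server.py --qwen2vl --port 8080   # serve Qwen2VL on port 8080
    python server.py --florence              # serve Florence-2 on the default port 8000
    python server.py                         # no flag: falls back to Florence-2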