feat(ml): use float32 for pascal

This commit is contained in:
arkohut 2024-09-06 23:42:46 +08:00
parent dbbc2792ef
commit a5562b14eb

View File

@ -25,7 +25,12 @@ elif torch.backends.mps.is_available():
else:
device = torch.device("cpu")
torch_dtype = "auto"
torch_dtype = (
torch.float32
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
print(f"Using device: {device}")
@ -37,7 +42,7 @@ def init_embedding_model():
return model
embedding_model = init_embedding_model() # 初始化模型
embedding_model = init_embedding_model()
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
@ -139,9 +144,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
text = qwen2vl_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = qwen2vl_processor(
text=[text],
images=image_inputs,
@ -152,12 +157,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
inputs = inputs.to(device)
generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = qwen2vl_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
@ -275,10 +280,14 @@ async def chat_completions(request: ChatCompletionRequest):
if __name__ == "__main__":
import uvicorn
parser = argparse.ArgumentParser(description="Run the server with specified model and port")
parser = argparse.ArgumentParser(
description="Run the server with specified model and port"
)
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args()
if args.florence and args.qwen2vl: