feat(ml): use float32 for Pascal GPUs

This commit is contained in:
arkohut 2024-09-06 23:42:46 +08:00
parent dbbc2792ef
commit a5562b14eb

View File

@ -25,7 +25,12 @@ elif torch.backends.mps.is_available():
else: else:
device = torch.device("cpu") device = torch.device("cpu")
torch_dtype = "auto" torch_dtype = (
torch.float32
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
print(f"Using device: {device}") print(f"Using device: {device}")
@ -37,7 +42,7 @@ def init_embedding_model():
return model return model
embedding_model = init_embedding_model() # 初始化模型 embedding_model = init_embedding_model()
def generate_embeddings(input_texts: List[str]) -> List[List[float]]: def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
@ -275,10 +280,14 @@ async def chat_completions(request: ChatCompletionRequest):
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
parser = argparse.ArgumentParser(description="Run the server with specified model and port") parser = argparse.ArgumentParser(
description="Run the server with specified model and port"
)
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model") parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args() args = parser.parse_args()
if args.florence and args.qwen2vl: if args.florence and args.qwen2vl: