mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
feat: cuda change
This commit is contained in:
parent
e9d980fdd4
commit
fa0ce75a5f
@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import List, Any
|
||||
import gdown
|
||||
import numpy as np
|
||||
from typing import List, Any
|
||||
from deepface.commons import package_utils, folder_utils
|
||||
from deepface.models.FacialRecognition import FacialRecognition
|
||||
from deepface.commons.logger import Logger
|
||||
@ -1742,15 +1742,17 @@ def load_facenet512d_onnx_model(
|
||||
model (Any)
|
||||
"""
|
||||
try:
|
||||
import torch # https://stackoverflow.com/questions/75267445
|
||||
import onnxruntime as ort
|
||||
except ModuleNotFoundError as e:
|
||||
raise ImportError(
|
||||
"FaceNet512dONNX is an optional model, ensure the library is installed. "
|
||||
"Please install using 'pip install onnxruntime' or 'pip install onnxruntime-gpu' to use gpu"
|
||||
"FaceNet512ONNX is an optional model, ensure the library is installed. "
|
||||
"Please install using 'pip install onnxruntime' or "
|
||||
"'pip install onnxruntime-gpu' to use gpu"
|
||||
) from e
|
||||
|
||||
if ort.get_device() == "GPU":
|
||||
logger.info(f"using onnx GPU for inference")
|
||||
if torch.cuda.is_available():
|
||||
logger.info("using onnx GPU for inference")
|
||||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
providers = ['CPUExecutionProvider']
|
||||
|
Loading…
x
Reference in New Issue
Block a user