diff --git a/deepface/models/face_detection/MtCnn.py b/deepface/models/face_detection/MtCnn.py index de43b96..b3bb306 100644 --- a/deepface/models/face_detection/MtCnn.py +++ b/deepface/models/face_detection/MtCnn.py @@ -1,4 +1,5 @@ # built-in dependencies +import logging from typing import List, Union # 3rd party dependencies @@ -8,6 +9,8 @@ from mtcnn import MTCNN # project dependencies from deepface.models.Detector import Detector, FacialAreaRegion +logger = logging.getLogger(__name__) + # pylint: disable=too-few-public-methods class MtCnnClient(Detector): """ @@ -16,6 +19,7 @@ class MtCnnClient(Detector): def __init__(self): self.model = MTCNN() + self.supports_batch_detection = self._supports_batch_detection() def detect_faces( self, @@ -42,7 +46,10 @@ class MtCnnClient(Detector): # mtcnn expects RGB but OpenCV read BGR # img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_rgb = [img[:, :, ::-1] for img in img] - detections = self.model.detect_faces(img_rgb) + if self.supports_batch_detection: + detections = self.model.detect_faces(img_rgb) + else: + detections = [self.model.detect_faces(single_img) for single_img in img_rgb] for image_detections in detections: image_resp = [] @@ -72,3 +79,21 @@ class MtCnnClient(Detector): if len(resp) == 1: return resp[0] return resp + + def _supports_batch_detection(self) -> bool: + import mtcnn + try: + mtcnn_version = mtcnn.__version__ + except AttributeError: + try: + import mtcnn.metadata + mtcnn_version = mtcnn.metadata.__version__ + except AttributeError: + logger.warning("Failed to determine mtcnn version") + logger.warning("Fallback to single image detection") + return False + supports_batch_detection = mtcnn_version >= "1.0.0" + if not supports_batch_detection: + logger.warning("MtCnn version is less than 1.0.0, batch detection is not supported") + logger.warning("Fallback to single image detection") + return supports_batch_detection