diff --git a/deepface/detectors/FaceDetector.py b/deepface/detectors/FaceDetector.py index ff30f76..e12500d 100644 --- a/deepface/detectors/FaceDetector.py +++ b/deepface/detectors/FaceDetector.py @@ -6,21 +6,18 @@ from deepface.commons import distance def build_model(detector_backend): - if detector_backend == 'opencv': - face_detector = OpenCvWrapper.build_model() + backends = { + 'opencv': OpenCvWrapper.build_model, + 'ssd': SsdWrapper.build_model, + 'dlib': DlibWrapper.build_model, + 'mtcnn': MtcnnWrapper.build_model, + 'retinaface': RetinaFaceWrapper.build_model + } - elif detector_backend == 'ssd': - face_detector = SsdWrapper.build_model() - - elif detector_backend == 'dlib': - face_detector = DlibWrapper.build_model() - - elif detector_backend == 'mtcnn': - face_detector = MtcnnWrapper.build_model() - - elif detector_backend == 'retinaface': - face_detector = RetinaFaceWrapper.build_model() + face_detector = backends.get(detector_backend) + if face_detector: + face_detector = face_detector() else: raise ValueError("invalid detector_backend passed - " + detector_backend) @@ -28,21 +25,18 @@ def build_model(detector_backend): def detect_face(face_detector, detector_backend, img): - if detector_backend == 'opencv': - face, region = OpenCvWrapper.detect_face(face_detector, img) + backends = { + 'opencv': OpenCvWrapper.detect_face, + 'ssd': SsdWrapper.detect_face, + 'dlib': DlibWrapper.detect_face, + 'mtcnn': MtcnnWrapper.detect_face, + 'retinaface': RetinaFaceWrapper.detect_face + } - elif detector_backend == 'ssd': - face, region = SsdWrapper.detect_face(face_detector, img) - - elif detector_backend == 'dlib': - face, region = DlibWrapper.detect_face(face_detector, img) - - elif detector_backend == 'mtcnn': - face, region = MtcnnWrapper.detect_face(face_detector, img) - - elif detector_backend == 'retinaface': - face, region = RetinaFaceWrapper.detect_face(face_detector, img) + detect_face = backends.get(detector_backend) + if detect_face: + face, region = detect_face(face_detector, img) else: raise ValueError("invalid detector_backend passed - " + detector_backend) diff --git a/deepface/detectors/__init__.py b/deepface/detectors/__init__.py new file mode 100644 index 0000000..e69de29