diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 6723185..ae03407 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -35,7 +35,7 @@ def build_model(model_name): built deepface model """ - global model_obj, model_label + global model_obj models = { 'VGG-Face': VGGFace.loadModel, @@ -48,22 +48,22 @@ def build_model(model_name): 'Emotion': Emotion.loadModel, 'Age': Age.loadModel, 'Gender': Gender.loadModel, - 'Race': Race.loadModel, - 'Ensemble': Boosting.loadModel + 'Race': Race.loadModel } - if not "model_obj" in globals() or model_label != model_name: + if not "model_obj" in globals(): + model_obj = {} - model_obj = models.get(model_name) - - if model_obj: - model_obj = model_obj() - model_label = model_name - #print('Using {} model backend'.format(model_name)) + if not model_name in model_obj.keys(): + model = models.get(model_name) + if model: + model = model() + model_obj[model_name] = model + #print(model_name," built") else: raise ValueError('Invalid model_name passed - {}'.format(model_name)) - return model_obj + return model_obj[model_name] def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True, detector_backend = 'opencv', align = True): @@ -123,7 +123,7 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = if model == None: if model_name == 'Ensemble': - models = build_model(model_name) + models = Boosting.loadModel() else: model = build_model(model_name) models = {} @@ -502,7 +502,7 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine', if model_name == 'Ensemble': print("Ensemble learning enabled") - models = build_model(model_name) + models = Boosting.loadModel() else: #model is not ensemble model = build_model(model_name) diff --git a/deepface/detectors/FaceDetector.py b/deepface/detectors/FaceDetector.py index 7c43b5a..dcff5dd 100644 --- a/deepface/detectors/FaceDetector.py +++ b/deepface/detectors/FaceDetector.py @@ -6,7 +6,7 @@ from deepface.commons import distance def build_model(detector_backend): - global face_detector_obj, face_detector_label + global face_detector_obj backends = { 'opencv': OpenCvWrapper.build_model, @@ -16,16 +16,20 @@ def build_model(detector_backend): 'retinaface': RetinaFaceWrapper.build_model } - if not "face_detector_obj" in globals() or face_detector_label != detector_backend: - face_detector_obj = backends.get(detector_backend) - face_detector_label = detector_backend + if not "face_detector_obj" in globals(): + face_detector_obj = {} - if face_detector_obj: - face_detector_obj = face_detector_obj() + if not detector_backend in face_detector_obj.keys(): + face_detector = backends.get(detector_backend) + + if face_detector: + face_detector = face_detector() + face_detector_obj[detector_backend] = face_detector + #print(detector_backend," built") else: raise ValueError("invalid detector_backend passed - " + detector_backend) - return face_detector_obj + return face_detector_obj[detector_backend] def detect_face(face_detector, detector_backend, img, align = True): diff --git a/tests/unit_tests.py b/tests/unit_tests.py index bd6c699..85262f4 100644 --- a/tests/unit_tests.py +++ b/tests/unit_tests.py @@ -168,7 +168,7 @@ dataset = [ ['dataset/img1.jpg', 'dataset/img2.jpg', True], ['dataset/img5.jpg', 'dataset/img6.jpg', True], ['dataset/img6.jpg', 'dataset/img7.jpg', True], - ['dataset/img8.jpg', 'dataset/img9.jpg', True], + #['dataset/img8.jpg', 'dataset/img9.jpg', True], ['dataset/img1.jpg', 'dataset/img11.jpg', True], ['dataset/img2.jpg', 'dataset/img11.jpg', True], @@ -238,10 +238,10 @@ else: print("Analyze function with passing pre-trained model") -emotion_model = Emotion.loadModel() -age_model = Age.loadModel() -gender_model = Gender.loadModel() -race_model = Race.loadModel() +emotion_model = DeepFace.build_model("Emotion") +age_model = DeepFace.build_model("Age") +gender_model = DeepFace.build_model("Gender") +race_model = DeepFace.build_model("Race") facial_attribute_models = {} facial_attribute_models["emotion"] = emotion_model @@ -277,9 +277,9 @@ print("--------------------------") print("Pre-trained ensemble method - find") from deepface import DeepFace -from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace +from deepface.basemodels import Boosting -model = DeepFace.build_model("Ensemble") +model = Boosting.loadModel() df = DeepFace.find("dataset/img1.jpg", db_path = "dataset", model_name = 'Ensemble', model = model, enforce_detection=False) print(df)