diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 32f79a8..d69e07b 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -35,6 +35,8 @@ def build_model(model_name): built deepface model """ + global model_obj, model_label + models = { 'VGG-Face': VGGFace.loadModel, 'OpenFace': OpenFace.loadModel, @@ -43,21 +45,24 @@ def build_model(model_name): 'DeepID': DeepID.loadModel, 'Dlib': DlibWrapper.loadModel, 'ArcFace': ArcFace.loadModel, - 'Emotion': Emotion.loadModel, 'Age': Age.loadModel, 'Gender': Gender.loadModel, 'Race': Race.loadModel } - model = models.get(model_name) + if not "model_obj" in globals() or model_label != model_name: - if model: - model = model() - #print('Using {} model backend'.format(model_name)) - return model - else: - raise ValueError('Invalid model_name passed - {}'.format(model_name)) + model_obj = models.get(model_name) + + if model_obj: + model_obj = model_obj() + model_label = model_name + #print('Using {} model backend'.format(model_name)) + else: + raise ValueError('Invalid model_name passed - {}'.format(model_name)) + + return model_obj def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True, detector_backend = 'opencv', align = True): diff --git a/tests/unit_tests.py b/tests/unit_tests.py index 20bfdd3..176154a 100644 --- a/tests/unit_tests.py +++ b/tests/unit_tests.py @@ -186,8 +186,8 @@ metrics = ['cosine', 'euclidean', 'euclidean_l2'] passed_tests = 0; test_cases = 0 for model in models: - prebuilt_model = DeepFace.build_model(model) - print(model," is built") + #prebuilt_model = DeepFace.build_model(model) + #print(model," is built") for metric in metrics: for instance in dataset: img1 = instance[0] @@ -195,7 +195,8 @@ for model in models: result = instance[2] resp_obj = DeepFace.verify(img1, img2 - , model_name = model, model = prebuilt_model + , model_name = model + #, model = prebuilt_model , distance_metric = metric) prediction = resp_obj["verified"]