diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index e5dc176..e4f1f94 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -23,7 +23,7 @@ from deepface.commons import functions, realtime, distance as dst def analyze_init(models = []): #--------------------------------- - + built_models = {} #if a specific target is not passed, then find them all if len(models) == 0: models = ['emotion', 'age', 'gender', 'race'] @@ -33,16 +33,17 @@ def analyze_init(models = []): #--------------------------------- if 'emotion' in models: - emotion_model = Emotion.loadModel() + built_models['emotion'] = Emotion.loadModel() if 'age' in models: - age_model = Age.loadModel() + built_models['age'] = Age.loadModel() if 'gender' in models: - gender_model = Gender.loadModel() + built_models['gender'] = Gender.loadModel() if 'race' in models: - race_model = Race.loadModel() + built_models['race'] = Race.loadModel() + return built_models def verify_init(model_name = 'VGG-Face'): @@ -63,6 +64,7 @@ def verify_init(model_name = 'VGG-Face'): model = FbDeepFace.loadModel() else: raise ValueError("Invalid model_name passed - ", model_name) + return model def verify(img1_path, img2_path='' @@ -193,7 +195,8 @@ def verify(img1_path, img2_path='' return resp_obj #return resp_objects -def analyze(img_path, actions= []): + +def analyze(img_path, actions= [], models= {}): if type(img_path) == list: img_paths = img_path.copy() @@ -213,16 +216,28 @@ def analyze(img_path, actions= []): #--------------------------------- if 'emotion' in actions: - emotion_model = Emotion.loadModel() + if 'emotion' in models: + emotion_model = models['emotion'] + else: + emotion_model = Emotion.loadModel() if 'age' in actions: - age_model = Age.loadModel() + if 'age' in models: + age_model = models['age'] + else: + age_model = Age.loadModel() if 'gender' in actions: - gender_model = Gender.loadModel() + if 'gender' in models: + gender_model = models['gender'] + else: + gender_model = Gender.loadModel() if 'race' in actions: - race_model = Race.loadModel() + if 'race' in models: + race_model = models['race'] + else: + race_model = Race.loadModel() #--------------------------------- resp_objects = []