performance improvement

This commit is contained in:
Sefik Ilkin Serengil 2021-06-24 09:30:45 +03:00
parent 20b64893ff
commit 18c96e492e
3 changed files with 31 additions and 27 deletions

View File

@ -35,7 +35,7 @@ def build_model(model_name):
built deepface model
"""
global model_obj, model_label
global model_obj
models = {
'VGG-Face': VGGFace.loadModel,
@ -48,22 +48,22 @@ def build_model(model_name):
'Emotion': Emotion.loadModel,
'Age': Age.loadModel,
'Gender': Gender.loadModel,
'Race': Race.loadModel,
'Ensemble': Boosting.loadModel
'Race': Race.loadModel
}
if not "model_obj" in globals() or model_label != model_name:
if not "model_obj" in globals():
model_obj = {}
model_obj = models.get(model_name)
if model_obj:
model_obj = model_obj()
model_label = model_name
#print('Using {} model backend'.format(model_name))
if not model_name in model_obj.keys():
model = models.get(model_name)
if model:
model = model()
model_obj[model_name] = model
#print(model_name," built")
else:
raise ValueError('Invalid model_name passed - {}'.format(model_name))
return model_obj
return model_obj[model_name]
def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True, detector_backend = 'opencv', align = True):
@ -123,7 +123,7 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric =
if model == None:
if model_name == 'Ensemble':
models = build_model(model_name)
models = Boosting.loadModel()
else:
model = build_model(model_name)
models = {}
@ -502,7 +502,7 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
if model_name == 'Ensemble':
print("Ensemble learning enabled")
models = build_model(model_name)
models = Boosting.loadModel()
else: #model is not ensemble
model = build_model(model_name)

View File

@ -6,7 +6,7 @@ from deepface.commons import distance
def build_model(detector_backend):
global face_detector_obj, face_detector_label
global face_detector_obj
backends = {
'opencv': OpenCvWrapper.build_model,
@ -16,16 +16,20 @@ def build_model(detector_backend):
'retinaface': RetinaFaceWrapper.build_model
}
if not "face_detector_obj" in globals() or face_detector_label != detector_backend:
face_detector_obj = backends.get(detector_backend)
face_detector_label = detector_backend
if not "face_detector_obj" in globals():
face_detector_obj = {}
if face_detector_obj:
face_detector_obj = face_detector_obj()
if not detector_backend in face_detector_obj.keys():
face_detector = backends.get(detector_backend)
if face_detector:
face_detector = face_detector()
face_detector_obj[detector_backend] = face_detector
#print(detector_backend," built")
else:
raise ValueError("invalid detector_backend passed - " + detector_backend)
return face_detector_obj
return face_detector_obj[detector_backend]
def detect_face(face_detector, detector_backend, img, align = True):

View File

@ -168,7 +168,7 @@ dataset = [
['dataset/img1.jpg', 'dataset/img2.jpg', True],
['dataset/img5.jpg', 'dataset/img6.jpg', True],
['dataset/img6.jpg', 'dataset/img7.jpg', True],
['dataset/img8.jpg', 'dataset/img9.jpg', True],
#['dataset/img8.jpg', 'dataset/img9.jpg', True],
['dataset/img1.jpg', 'dataset/img11.jpg', True],
['dataset/img2.jpg', 'dataset/img11.jpg', True],
@ -238,10 +238,10 @@ else:
print("Analyze function with passing pre-trained model")
emotion_model = Emotion.loadModel()
age_model = Age.loadModel()
gender_model = Gender.loadModel()
race_model = Race.loadModel()
emotion_model = DeepFace.build_model("Emotion")
age_model = DeepFace.build_model("Age")
gender_model = DeepFace.build_model("Gender")
race_model = DeepFace.build_model("Race")
facial_attribute_models = {}
facial_attribute_models["emotion"] = emotion_model
@ -277,9 +277,9 @@ print("--------------------------")
print("Pre-trained ensemble method - find")
from deepface import DeepFace
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
from deepface.basemodels import Boosting
model = DeepFace.build_model("Ensemble")
model = Boosting.loadModel()
df = DeepFace.find("dataset/img1.jpg", db_path = "dataset", model_name = 'Ensemble', model = model, enforce_detection=False)
print(df)