mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 20:15:21 +00:00
performance improvement
This commit is contained in:
parent
20b64893ff
commit
18c96e492e
@ -35,7 +35,7 @@ def build_model(model_name):
|
|||||||
built deepface model
|
built deepface model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
global model_obj, model_label
|
global model_obj
|
||||||
|
|
||||||
models = {
|
models = {
|
||||||
'VGG-Face': VGGFace.loadModel,
|
'VGG-Face': VGGFace.loadModel,
|
||||||
@ -48,22 +48,22 @@ def build_model(model_name):
|
|||||||
'Emotion': Emotion.loadModel,
|
'Emotion': Emotion.loadModel,
|
||||||
'Age': Age.loadModel,
|
'Age': Age.loadModel,
|
||||||
'Gender': Gender.loadModel,
|
'Gender': Gender.loadModel,
|
||||||
'Race': Race.loadModel,
|
'Race': Race.loadModel
|
||||||
'Ensemble': Boosting.loadModel
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if not "model_obj" in globals() or model_label != model_name:
|
if not "model_obj" in globals():
|
||||||
|
model_obj = {}
|
||||||
|
|
||||||
model_obj = models.get(model_name)
|
if not model_name in model_obj.keys():
|
||||||
|
model = models.get(model_name)
|
||||||
if model_obj:
|
if model:
|
||||||
model_obj = model_obj()
|
model = model()
|
||||||
model_label = model_name
|
model_obj[model_name] = model
|
||||||
#print('Using {} model backend'.format(model_name))
|
#print(model_name," built")
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid model_name passed - {}'.format(model_name))
|
raise ValueError('Invalid model_name passed - {}'.format(model_name))
|
||||||
|
|
||||||
return model_obj
|
return model_obj[model_name]
|
||||||
|
|
||||||
def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True, detector_backend = 'opencv', align = True):
|
def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True, detector_backend = 'opencv', align = True):
|
||||||
|
|
||||||
@ -123,7 +123,7 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric =
|
|||||||
|
|
||||||
if model == None:
|
if model == None:
|
||||||
if model_name == 'Ensemble':
|
if model_name == 'Ensemble':
|
||||||
models = build_model(model_name)
|
models = Boosting.loadModel()
|
||||||
else:
|
else:
|
||||||
model = build_model(model_name)
|
model = build_model(model_name)
|
||||||
models = {}
|
models = {}
|
||||||
@ -502,7 +502,7 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
|
|||||||
|
|
||||||
if model_name == 'Ensemble':
|
if model_name == 'Ensemble':
|
||||||
print("Ensemble learning enabled")
|
print("Ensemble learning enabled")
|
||||||
models = build_model(model_name)
|
models = Boosting.loadModel()
|
||||||
|
|
||||||
else: #model is not ensemble
|
else: #model is not ensemble
|
||||||
model = build_model(model_name)
|
model = build_model(model_name)
|
||||||
|
@ -6,7 +6,7 @@ from deepface.commons import distance
|
|||||||
|
|
||||||
def build_model(detector_backend):
|
def build_model(detector_backend):
|
||||||
|
|
||||||
global face_detector_obj, face_detector_label
|
global face_detector_obj
|
||||||
|
|
||||||
backends = {
|
backends = {
|
||||||
'opencv': OpenCvWrapper.build_model,
|
'opencv': OpenCvWrapper.build_model,
|
||||||
@ -16,16 +16,20 @@ def build_model(detector_backend):
|
|||||||
'retinaface': RetinaFaceWrapper.build_model
|
'retinaface': RetinaFaceWrapper.build_model
|
||||||
}
|
}
|
||||||
|
|
||||||
if not "face_detector_obj" in globals() or face_detector_label != detector_backend:
|
if not "face_detector_obj" in globals():
|
||||||
face_detector_obj = backends.get(detector_backend)
|
face_detector_obj = {}
|
||||||
face_detector_label = detector_backend
|
|
||||||
|
|
||||||
if face_detector_obj:
|
if not detector_backend in face_detector_obj.keys():
|
||||||
face_detector_obj = face_detector_obj()
|
face_detector = backends.get(detector_backend)
|
||||||
|
|
||||||
|
if face_detector:
|
||||||
|
face_detector = face_detector()
|
||||||
|
face_detector_obj[detector_backend] = face_detector
|
||||||
|
#print(detector_backend," built")
|
||||||
else:
|
else:
|
||||||
raise ValueError("invalid detector_backend passed - " + detector_backend)
|
raise ValueError("invalid detector_backend passed - " + detector_backend)
|
||||||
|
|
||||||
return face_detector_obj
|
return face_detector_obj[detector_backend]
|
||||||
|
|
||||||
def detect_face(face_detector, detector_backend, img, align = True):
|
def detect_face(face_detector, detector_backend, img, align = True):
|
||||||
|
|
||||||
|
@ -168,7 +168,7 @@ dataset = [
|
|||||||
['dataset/img1.jpg', 'dataset/img2.jpg', True],
|
['dataset/img1.jpg', 'dataset/img2.jpg', True],
|
||||||
['dataset/img5.jpg', 'dataset/img6.jpg', True],
|
['dataset/img5.jpg', 'dataset/img6.jpg', True],
|
||||||
['dataset/img6.jpg', 'dataset/img7.jpg', True],
|
['dataset/img6.jpg', 'dataset/img7.jpg', True],
|
||||||
['dataset/img8.jpg', 'dataset/img9.jpg', True],
|
#['dataset/img8.jpg', 'dataset/img9.jpg', True],
|
||||||
|
|
||||||
['dataset/img1.jpg', 'dataset/img11.jpg', True],
|
['dataset/img1.jpg', 'dataset/img11.jpg', True],
|
||||||
['dataset/img2.jpg', 'dataset/img11.jpg', True],
|
['dataset/img2.jpg', 'dataset/img11.jpg', True],
|
||||||
@ -238,10 +238,10 @@ else:
|
|||||||
|
|
||||||
print("Analyze function with passing pre-trained model")
|
print("Analyze function with passing pre-trained model")
|
||||||
|
|
||||||
emotion_model = Emotion.loadModel()
|
emotion_model = DeepFace.build_model("Emotion")
|
||||||
age_model = Age.loadModel()
|
age_model = DeepFace.build_model("Age")
|
||||||
gender_model = Gender.loadModel()
|
gender_model = DeepFace.build_model("Gender")
|
||||||
race_model = Race.loadModel()
|
race_model = DeepFace.build_model("Race")
|
||||||
|
|
||||||
facial_attribute_models = {}
|
facial_attribute_models = {}
|
||||||
facial_attribute_models["emotion"] = emotion_model
|
facial_attribute_models["emotion"] = emotion_model
|
||||||
@ -277,9 +277,9 @@ print("--------------------------")
|
|||||||
print("Pre-trained ensemble method - find")
|
print("Pre-trained ensemble method - find")
|
||||||
|
|
||||||
from deepface import DeepFace
|
from deepface import DeepFace
|
||||||
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
|
from deepface.basemodels import Boosting
|
||||||
|
|
||||||
model = DeepFace.build_model("Ensemble")
|
model = Boosting.loadModel()
|
||||||
df = DeepFace.find("dataset/img1.jpg", db_path = "dataset", model_name = 'Ensemble', model = model, enforce_detection=False)
|
df = DeepFace.find("dataset/img1.jpg", db_path = "dataset", model_name = 'Ensemble', model = model, enforce_detection=False)
|
||||||
|
|
||||||
print(df)
|
print(df)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user