Mirror of https://github.com/serengil/deepface.git (synced 2025-06-09 04:55:24 +00:00)

commit 005f3bee31 (parent ba480a323b)
global model building function created
@@ -16,9 +16,29 @@ from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
from deepface.extendedmodels import Age, Gender, Race, Emotion
from deepface.commons import functions, realtime, distance as dst

def DlibResNet_():
	from deepface.basemodels.DlibResNet import DlibResNet
	return DlibResNet()

def build_model(model_name):

	models = {
		'VGG-Face': VGGFace.loadModel,
		'OpenFace': OpenFace.loadModel,
		'Facenet': Facenet.loadModel,
		'DeepFace': FbDeepFace.loadModel,
		'DeepID': DeepID.loadModel,
		'Dlib': DlibResNet_,
		'Emotion': Emotion.loadModel,
		'Age': Age.loadModel,
		'Gender': Gender.loadModel,
		'Race': Race.loadModel
	}

	model = models.get(model_name)

	if model:
		model = model()
		print('Using {} model backend'.format(model_name))
		return model
	else:
		raise ValueError('Invalid model_name passed - {}'.format(model_name))

def verify(img1_path, img2_path = '', model_name='VGG-Face', distance_metric='cosine',
	model=None, enforce_detection=True, detector_backend = 'opencv'):
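The new build_model function centralizes model construction behind a single name-to-loader dictionary, so every entry point can build a backend the same way. A minimal usage sketch (the model name is one of the keys shown above; build_model lives in deepface/DeepFace.py after this commit):

	from deepface import DeepFace

	#build the pre-trained recognition model once and reuse it across calls
	model = DeepFace.build_model('VGG-Face')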
@@ -55,16 +75,16 @@ def verify(img1_path, img2_path = '', model_name='VGG-Face', distance_metric='co

		if index == 0:
			model_pbar.set_description("Loading VGG-Face")
			model["VGG-Face"] = VGGFace.loadModel()
			model["VGG-Face"] = build_model('VGG-Face')
		elif index == 1:
			model_pbar.set_description("Loading Google FaceNet")
			model["Facenet"] = Facenet.loadModel()
			model["Facenet"] = build_model('Facenet')
		elif index == 2:
			model_pbar.set_description("Loading OpenFace")
			model["OpenFace"] = OpenFace.loadModel()
			model["OpenFace"] = build_model('OpenFace')
		elif index == 3:
			model_pbar.set_description("Loading Facebook DeepFace")
			model["DeepFace"] = FbDeepFace.loadModel()
			model["DeepFace"] = build_model('DeepFace')

	#--------------------------
	#validate model dictionary because it might be passed from input as pre-trained
@@ -204,22 +224,7 @@ def verify(img1_path, img2_path = '', model_name='VGG-Face', distance_metric='co
	#ensemble learning disabled

	if model == None:

		models = {
			'VGG-Face': VGGFace.loadModel,
			'OpenFace': OpenFace.loadModel,
			'Facenet': Facenet.loadModel,
			'DeepFace': FbDeepFace.loadModel,
			'DeepID': DeepID.loadModel,
			'Dlib': DlibResNet_
		}

		model = models.get(model_name)
		if model:
			model = model()
			print('Using {} model backend and {} distance'.format(model_name, distance_metric))
		else:
			raise ValueError('Invalid model_name passed - {}'.format(model_name))
		model = build_model(model_name)

	else: #model != None
		print("Already built model is passed")
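Because verify still accepts an already built model through its model parameter, it can be combined with build_model to avoid reloading weights on every call. A short sketch (the image file names are placeholders):

	from deepface import DeepFace

	model = DeepFace.verify  #see signature above; a pre-built model skips the "if model == None" branch
	model = DeepFace.build_model('VGG-Face')
	result = DeepFace.verify("img1.jpg", "img2.jpg", model_name = 'VGG-Face', model = model)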
@@ -350,6 +355,12 @@ def analyze(img_path, actions = [], models = {}, enforce_detection = True, detec

	#---------------------------------

	#build mtcnn model once
	if detector_backend == 'mtcnn':
		functions.load_mtcnn()

	#---------------------------------

	#if a specific target is not passed, then find them all
	if len(actions) == 0:
		actions= ['emotion', 'age', 'gender', 'race']
@@ -363,28 +374,28 @@ def analyze(img_path, actions = [], models = {}, enforce_detection = True, detec
			print("already built emotion model is passed")
			emotion_model = models['emotion']
		else:
			emotion_model = Emotion.loadModel()
			emotion_model = build_model('Emotion')

	if 'age' in actions:
		if 'age' in models:
			#print("already built age model is passed")
			age_model = models['age']
		else:
			age_model = Age.loadModel()
			age_model = build_model('Age')

	if 'gender' in actions:
		if 'gender' in models:
			print("already built gender model is passed")
			gender_model = models['gender']
		else:
			gender_model = Gender.loadModel()
			gender_model = build_model('Gender')

	if 'race' in actions:
		if 'race' in models:
			print("already built race model is passed")
			race_model = models['race']
		else:
			race_model = Race.loadModel()
			race_model = build_model('Race')
	#---------------------------------

	resp_objects = []
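The same pattern applies to the attribute models in analyze: they can be built once with build_model and passed in through the models dictionary, which analyze checks key by key as shown above. A minimal sketch (the image path and the selected actions are illustrative):

	from deepface import DeepFace

	models = {}
	models['emotion'] = DeepFace.build_model('Emotion')
	models['age'] = DeepFace.build_model('Age')

	demography = DeepFace.analyze("img.jpg", actions = ['emotion', 'age'], models = models)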
@@ -508,6 +519,11 @@ def analyze(img_path, actions = [], models = {}, enforce_detection = True, detec
	#return resp_objects

def detectFace(img_path, detector_backend = 'opencv'):

	#build mtcnn model once
	if detector_backend == 'mtcnn':
		functions.load_mtcnn()

	img = functions.preprocess_face(img = img_path, detector_backend = detector_backend)[0] #preprocess_face returns (1, 224, 224, 3)
	return img[:, :, ::-1] #bgr to rgb
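detectFace now also initializes the MTCNN detector once when that backend is requested. A usage sketch (the image path is a placeholder):

	from deepface import DeepFace

	#returns the detected, aligned and preprocessed face as an RGB array
	face = DeepFace.detectFace("img.jpg", detector_backend = 'mtcnn')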
@@ -525,31 +541,21 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
		bulkProcess = False
		img_paths = [img_path]

	#-------------------------------

	#build mtcnn model once
	if detector_backend == 'mtcnn':
		functions.load_mtcnn()

	#-------------------------------

	if os.path.isdir(db_path) == True:

		#---------------------------------------

		if model == None:
			if model_name == 'VGG-Face':
				print("Using VGG-Face model backend and", distance_metric,"distance.")
				model = VGGFace.loadModel()
			elif model_name == 'OpenFace':
				print("Using OpenFace model backend", distance_metric,"distance.")
				model = OpenFace.loadModel()
			elif model_name == 'Facenet':
				print("Using Facenet model backend", distance_metric,"distance.")
				model = Facenet.loadModel()
			elif model_name == 'DeepFace':
				print("Using FB DeepFace model backend", distance_metric,"distance.")
				model = FbDeepFace.loadModel()
			elif model_name == 'DeepID':
				print("Using DeepID model backend", distance_metric,"distance.")
				model = DeepID.loadModel()
			elif model_name == 'Dlib':
				print("Using Dlib ResNet model backend", distance_metric,"distance.")
				from deepface.basemodels.DlibResNet import DlibResNet #this is not a must because it is very huge
				model = DlibResNet()
			elif model_name == 'Ensemble':

			if model_name == 'Ensemble':
				print("Ensemble learning enabled")
				#TODO: include DeepID in ensemble method
@@ -562,19 +568,20 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
				for index in pbar:
					if index == 0:
						pbar.set_description("Loading VGG-Face")
						models['VGG-Face'] = VGGFace.loadModel()
						models['VGG-Face'] = build_model('VGG-Face')
					elif index == 1:
						pbar.set_description("Loading FaceNet")
						models['Facenet'] = Facenet.loadModel()
						models['Facenet'] = build_model('Facenet')
					elif index == 2:
						pbar.set_description("Loading OpenFace")
						models['OpenFace'] = OpenFace.loadModel()
						models['OpenFace'] = build_model('OpenFace')
					elif index == 3:
						pbar.set_description("Loading DeepFace")
						models['DeepFace'] = FbDeepFace.loadModel()
						models['DeepFace'] = build_model('DeepFace')

			else: #model is not ensemble
				model = build_model(model_name)

			else:
				raise ValueError("Invalid model_name passed - ", model_name)
		else: #model != None
			print("Already built model is passed")
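With the per-model if/elif chain replaced, find now routes both the single-backend and ensemble cases through build_model. A hedged sketch (the probe image and database folder paths are placeholders):

	from deepface import DeepFace

	#search a folder of facial images for matches to the probe image, using the ensemble
	df = DeepFace.find(img_path = "img.jpg", db_path = "my_db", model_name = 'Ensemble')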
@@ -864,6 +871,10 @@ def allocateMemory():
	print("Analyzing your system...")
	functions.allocateMemory()

def DlibResNet_():
	#this is not a regular Keras model.
	from deepface.basemodels.DlibResNet import DlibResNet
	return DlibResNet()
#---------------------------
#main
@@ -244,8 +244,8 @@ def detect_face(img, detector_backend = 'opencv', grayscale = False, enforce_det

	elif detector_backend == 'mtcnn':

		# mtcnn_detector = MTCNN()
		img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
		# mtcnn_detector = MTCNN() #this is a global variable now
		img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #mtcnn expects RGB but OpenCV read BGR
		detections = mtcnn_detector.detect_faces(img_rgb)

		if len(detections) > 0:
@@ -399,8 +399,8 @@ def align_face(img, detector_backend = 'opencv'):

	elif detector_backend == 'mtcnn':

		# mtcnn_detector = MTCNN()
		img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
		# mtcnn_detector = MTCNN() #this is a global variable now
		img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #mtcnn expects RGB but OpenCV read BGR
		detections = mtcnn_detector.detect_faces(img_rgb)

		if len(detections) > 0:
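The functions.py hunks assume mtcnn_detector is now a module-level global instead of a detector built inside every detect_face and align_face call. The body of functions.load_mtcnn is not part of this diff; a plausible minimal sketch of the build-once pattern it implements, under that assumption, is:

	from mtcnn import MTCNN

	mtcnn_detector = None  #module-level global shared by detect_face and align_face (assumed)

	def load_mtcnn():
		#build the MTCNN detector only once and reuse it on later calls (assumed behaviour)
		global mtcnn_detector
		if mtcnn_detector is None:
			mtcnn_detector = MTCNN()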