mirror of
https://github.com/serengil/deepface.git
synced 2025-06-09 12:57:08 +00:00
global model building function created
This commit is contained in:
parent
ba480a323b
commit
005f3bee31
@ -16,9 +16,29 @@ from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
|
|||||||
from deepface.extendedmodels import Age, Gender, Race, Emotion
|
from deepface.extendedmodels import Age, Gender, Race, Emotion
|
||||||
from deepface.commons import functions, realtime, distance as dst
|
from deepface.commons import functions, realtime, distance as dst
|
||||||
|
|
||||||
def DlibResNet_():
|
def build_model(model_name):
|
||||||
from deepface.basemodels.DlibResNet import DlibResNet
|
|
||||||
return DlibResNet()
|
models = {
|
||||||
|
'VGG-Face': VGGFace.loadModel,
|
||||||
|
'OpenFace': OpenFace.loadModel,
|
||||||
|
'Facenet': Facenet.loadModel,
|
||||||
|
'DeepFace': FbDeepFace.loadModel,
|
||||||
|
'DeepID': DeepID.loadModel,
|
||||||
|
'Dlib': DlibResNet_,
|
||||||
|
'Emotion': Emotion.loadModel,
|
||||||
|
'Age': Age.loadModel,
|
||||||
|
'Gender': Gender.loadModel,
|
||||||
|
'Race': Race.loadModel
|
||||||
|
}
|
||||||
|
|
||||||
|
model = models.get(model_name)
|
||||||
|
|
||||||
|
if model:
|
||||||
|
model = model()
|
||||||
|
print('Using {} model backend'.format(model_name))
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
raise ValueError('Invalid model_name passed - {}'.format(model_name))
|
||||||
|
|
||||||
def verify(img1_path, img2_path = '', model_name='VGG-Face', distance_metric='cosine',
|
def verify(img1_path, img2_path = '', model_name='VGG-Face', distance_metric='cosine',
|
||||||
model=None, enforce_detection=True, detector_backend = 'opencv'):
|
model=None, enforce_detection=True, detector_backend = 'opencv'):
|
||||||
@ -55,16 +75,16 @@ def verify(img1_path, img2_path = '', model_name='VGG-Face', distance_metric='co
|
|||||||
|
|
||||||
if index == 0:
|
if index == 0:
|
||||||
model_pbar.set_description("Loading VGG-Face")
|
model_pbar.set_description("Loading VGG-Face")
|
||||||
model["VGG-Face"] = VGGFace.loadModel()
|
model["VGG-Face"] = build_model('VGG-Face')
|
||||||
elif index == 1:
|
elif index == 1:
|
||||||
model_pbar.set_description("Loading Google FaceNet")
|
model_pbar.set_description("Loading Google FaceNet")
|
||||||
model["Facenet"] = Facenet.loadModel()
|
model["Facenet"] = build_model('Facenet')
|
||||||
elif index == 2:
|
elif index == 2:
|
||||||
model_pbar.set_description("Loading OpenFace")
|
model_pbar.set_description("Loading OpenFace")
|
||||||
model["OpenFace"] = OpenFace.loadModel()
|
model["OpenFace"] = build_model('OpenFace')
|
||||||
elif index == 3:
|
elif index == 3:
|
||||||
model_pbar.set_description("Loading Facebook DeepFace")
|
model_pbar.set_description("Loading Facebook DeepFace")
|
||||||
model["DeepFace"] = FbDeepFace.loadModel()
|
model["DeepFace"] = build_model('DeepFace')
|
||||||
|
|
||||||
#--------------------------
|
#--------------------------
|
||||||
#validate model dictionary because it might be passed from input as pre-trained
|
#validate model dictionary because it might be passed from input as pre-trained
|
||||||
@ -204,22 +224,7 @@ def verify(img1_path, img2_path = '', model_name='VGG-Face', distance_metric='co
|
|||||||
#ensemble learning disabled
|
#ensemble learning disabled
|
||||||
|
|
||||||
if model == None:
|
if model == None:
|
||||||
|
model = build_model(model_name)
|
||||||
models = {
|
|
||||||
'VGG-Face': VGGFace.loadModel,
|
|
||||||
'OpenFace': OpenFace.loadModel,
|
|
||||||
'Facenet': Facenet.loadModel,
|
|
||||||
'DeepFace': FbDeepFace.loadModel,
|
|
||||||
'DeepID': DeepID.loadModel,
|
|
||||||
'Dlib': DlibResNet_
|
|
||||||
}
|
|
||||||
|
|
||||||
model = models.get(model_name)
|
|
||||||
if model:
|
|
||||||
model = model()
|
|
||||||
print('Using {} model backend and {} distance'.format(model_name, distance_metric))
|
|
||||||
else:
|
|
||||||
raise ValueError('Invalid model_name passed - {}'.format(model_name))
|
|
||||||
|
|
||||||
else: #model != None
|
else: #model != None
|
||||||
print("Already built model is passed")
|
print("Already built model is passed")
|
||||||
@ -350,6 +355,12 @@ def analyze(img_path, actions = [], models = {}, enforce_detection = True, detec
|
|||||||
|
|
||||||
#---------------------------------
|
#---------------------------------
|
||||||
|
|
||||||
|
#build mtcnn model once
|
||||||
|
if detector_backend == 'mtcnn':
|
||||||
|
functions.load_mtcnn()
|
||||||
|
|
||||||
|
#---------------------------------
|
||||||
|
|
||||||
#if a specific target is not passed, then find them all
|
#if a specific target is not passed, then find them all
|
||||||
if len(actions) == 0:
|
if len(actions) == 0:
|
||||||
actions= ['emotion', 'age', 'gender', 'race']
|
actions= ['emotion', 'age', 'gender', 'race']
|
||||||
@ -363,28 +374,28 @@ def analyze(img_path, actions = [], models = {}, enforce_detection = True, detec
|
|||||||
print("already built emotion model is passed")
|
print("already built emotion model is passed")
|
||||||
emotion_model = models['emotion']
|
emotion_model = models['emotion']
|
||||||
else:
|
else:
|
||||||
emotion_model = Emotion.loadModel()
|
emotion_model = build_model('Emotion')
|
||||||
|
|
||||||
if 'age' in actions:
|
if 'age' in actions:
|
||||||
if 'age' in models:
|
if 'age' in models:
|
||||||
#print("already built age model is passed")
|
#print("already built age model is passed")
|
||||||
age_model = models['age']
|
age_model = models['age']
|
||||||
else:
|
else:
|
||||||
age_model = Age.loadModel()
|
age_model = build_model('Age')
|
||||||
|
|
||||||
if 'gender' in actions:
|
if 'gender' in actions:
|
||||||
if 'gender' in models:
|
if 'gender' in models:
|
||||||
print("already built gender model is passed")
|
print("already built gender model is passed")
|
||||||
gender_model = models['gender']
|
gender_model = models['gender']
|
||||||
else:
|
else:
|
||||||
gender_model = Gender.loadModel()
|
gender_model = build_model('Gender')
|
||||||
|
|
||||||
if 'race' in actions:
|
if 'race' in actions:
|
||||||
if 'race' in models:
|
if 'race' in models:
|
||||||
print("already built race model is passed")
|
print("already built race model is passed")
|
||||||
race_model = models['race']
|
race_model = models['race']
|
||||||
else:
|
else:
|
||||||
race_model = Race.loadModel()
|
race_model = build_model('Race')
|
||||||
#---------------------------------
|
#---------------------------------
|
||||||
|
|
||||||
resp_objects = []
|
resp_objects = []
|
||||||
@ -508,6 +519,11 @@ def analyze(img_path, actions = [], models = {}, enforce_detection = True, detec
|
|||||||
#return resp_objects
|
#return resp_objects
|
||||||
|
|
||||||
def detectFace(img_path, detector_backend = 'opencv'):
|
def detectFace(img_path, detector_backend = 'opencv'):
|
||||||
|
|
||||||
|
#build mtcnn model once
|
||||||
|
if detector_backend == 'mtcnn':
|
||||||
|
functions.load_mtcnn()
|
||||||
|
|
||||||
img = functions.preprocess_face(img = img_path, detector_backend = detector_backend)[0] #preprocess_face returns (1, 224, 224, 3)
|
img = functions.preprocess_face(img = img_path, detector_backend = detector_backend)[0] #preprocess_face returns (1, 224, 224, 3)
|
||||||
return img[:, :, ::-1] #bgr to rgb
|
return img[:, :, ::-1] #bgr to rgb
|
||||||
|
|
||||||
@ -525,31 +541,21 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
|
|||||||
bulkProcess = False
|
bulkProcess = False
|
||||||
img_paths = [img_path]
|
img_paths = [img_path]
|
||||||
|
|
||||||
|
#-------------------------------
|
||||||
|
|
||||||
|
#build mtcnn model once
|
||||||
|
if detector_backend == 'mtcnn':
|
||||||
|
functions.load_mtcnn()
|
||||||
|
|
||||||
|
#-------------------------------
|
||||||
|
|
||||||
if os.path.isdir(db_path) == True:
|
if os.path.isdir(db_path) == True:
|
||||||
|
|
||||||
#---------------------------------------
|
#---------------------------------------
|
||||||
|
|
||||||
if model == None:
|
if model == None:
|
||||||
if model_name == 'VGG-Face':
|
|
||||||
print("Using VGG-Face model backend and", distance_metric,"distance.")
|
if model_name == 'Ensemble':
|
||||||
model = VGGFace.loadModel()
|
|
||||||
elif model_name == 'OpenFace':
|
|
||||||
print("Using OpenFace model backend", distance_metric,"distance.")
|
|
||||||
model = OpenFace.loadModel()
|
|
||||||
elif model_name == 'Facenet':
|
|
||||||
print("Using Facenet model backend", distance_metric,"distance.")
|
|
||||||
model = Facenet.loadModel()
|
|
||||||
elif model_name == 'DeepFace':
|
|
||||||
print("Using FB DeepFace model backend", distance_metric,"distance.")
|
|
||||||
model = FbDeepFace.loadModel()
|
|
||||||
elif model_name == 'DeepID':
|
|
||||||
print("Using DeepID model backend", distance_metric,"distance.")
|
|
||||||
model = DeepID.loadModel()
|
|
||||||
elif model_name == 'Dlib':
|
|
||||||
print("Using Dlib ResNet model backend", distance_metric,"distance.")
|
|
||||||
from deepface.basemodels.DlibResNet import DlibResNet #this is not a must because it is very huge
|
|
||||||
model = DlibResNet()
|
|
||||||
elif model_name == 'Ensemble':
|
|
||||||
print("Ensemble learning enabled")
|
print("Ensemble learning enabled")
|
||||||
#TODO: include DeepID in ensemble method
|
#TODO: include DeepID in ensemble method
|
||||||
|
|
||||||
@ -562,19 +568,20 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
|
|||||||
for index in pbar:
|
for index in pbar:
|
||||||
if index == 0:
|
if index == 0:
|
||||||
pbar.set_description("Loading VGG-Face")
|
pbar.set_description("Loading VGG-Face")
|
||||||
models['VGG-Face'] = VGGFace.loadModel()
|
models['VGG-Face'] = build_model('VGG-Face')
|
||||||
elif index == 1:
|
elif index == 1:
|
||||||
pbar.set_description("Loading FaceNet")
|
pbar.set_description("Loading FaceNet")
|
||||||
models['Facenet'] = Facenet.loadModel()
|
models['Facenet'] = build_model('Facenet')
|
||||||
elif index == 2:
|
elif index == 2:
|
||||||
pbar.set_description("Loading OpenFace")
|
pbar.set_description("Loading OpenFace")
|
||||||
models['OpenFace'] = OpenFace.loadModel()
|
models['OpenFace'] = build_model('OpenFace')
|
||||||
elif index == 3:
|
elif index == 3:
|
||||||
pbar.set_description("Loading DeepFace")
|
pbar.set_description("Loading DeepFace")
|
||||||
models['DeepFace'] = FbDeepFace.loadModel()
|
models['DeepFace'] = build_model('DeepFace')
|
||||||
|
|
||||||
|
else: #model is not ensemble
|
||||||
|
model = build_model(model_name)
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid model_name passed - ", model_name)
|
|
||||||
else: #model != None
|
else: #model != None
|
||||||
print("Already built model is passed")
|
print("Already built model is passed")
|
||||||
|
|
||||||
@ -864,6 +871,10 @@ def allocateMemory():
|
|||||||
print("Analyzing your system...")
|
print("Analyzing your system...")
|
||||||
functions.allocateMemory()
|
functions.allocateMemory()
|
||||||
|
|
||||||
|
def DlibResNet_():
|
||||||
|
#this is not a regular Keras model.
|
||||||
|
from deepface.basemodels.DlibResNet import DlibResNet
|
||||||
|
return DlibResNet()
|
||||||
#---------------------------
|
#---------------------------
|
||||||
#main
|
#main
|
||||||
|
|
||||||
|
@ -244,8 +244,8 @@ def detect_face(img, detector_backend = 'opencv', grayscale = False, enforce_det
|
|||||||
|
|
||||||
elif detector_backend == 'mtcnn':
|
elif detector_backend == 'mtcnn':
|
||||||
|
|
||||||
# mtcnn_detector = MTCNN()
|
# mtcnn_detector = MTCNN() #this is a global variable now
|
||||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #mtcnn expects RGB but OpenCV read BGR
|
||||||
detections = mtcnn_detector.detect_faces(img_rgb)
|
detections = mtcnn_detector.detect_faces(img_rgb)
|
||||||
|
|
||||||
if len(detections) > 0:
|
if len(detections) > 0:
|
||||||
@ -399,8 +399,8 @@ def align_face(img, detector_backend = 'opencv'):
|
|||||||
|
|
||||||
elif detector_backend == 'mtcnn':
|
elif detector_backend == 'mtcnn':
|
||||||
|
|
||||||
# mtcnn_detector = MTCNN()
|
# mtcnn_detector = MTCNN() #this is a global variable now
|
||||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #mtcnn expects RGB but OpenCV read BGR
|
||||||
detections = mtcnn_detector.detect_faces(img_rgb)
|
detections = mtcnn_detector.detect_faces(img_rgb)
|
||||||
|
|
||||||
if len(detections) > 0:
|
if len(detections) > 0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user