deepid model added

This commit is contained in:
Şefik Serangil 2020-06-16 23:48:18 +03:00
parent cdde9ae160
commit 79e481ee7c
4 changed files with 86 additions and 7 deletions

View File

@ -15,7 +15,7 @@ import tensorflow as tf
import pickle import pickle
from deepface import DeepFace from deepface import DeepFace
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
from deepface.extendedmodels import Age, Gender, Race, Emotion from deepface.extendedmodels import Age, Gender, Race, Emotion
from deepface.commons import functions, realtime, distance as dst from deepface.commons import functions, realtime, distance as dst
@ -196,6 +196,10 @@ def verify(img1_path, img2_path=''
print("Using FB DeepFace model backend", distance_metric,"distance.") print("Using FB DeepFace model backend", distance_metric,"distance.")
model = FbDeepFace.loadModel() model = FbDeepFace.loadModel()
if model_name == 'DeepID':
print("Using DeepID2 model backend", distance_metric,"distance.")
model = DeepID.loadModel()
else: else:
raise ValueError("Invalid model_name passed - ", model_name) raise ValueError("Invalid model_name passed - ", model_name)
else: #model != None else: #model != None
@ -205,6 +209,9 @@ def verify(img1_path, img2_path=''
#face recognition models have different size of inputs #face recognition models have different size of inputs
input_shape = model.layers[0].input_shape[1:3] input_shape = model.layers[0].input_shape[1:3]
input_shape_x = input_shape[0]
input_shape_y = input_shape[1]
#------------------------------ #------------------------------
#tuned thresholds for model and metric pair #tuned thresholds for model and metric pair
@ -225,8 +232,8 @@ def verify(img1_path, img2_path=''
#---------------------- #----------------------
#crop and align faces #crop and align faces
img1 = functions.detectFace(img1_path, input_shape, enforce_detection = enforce_detection) img1 = functions.detectFace(img1_path, (input_shape_y, input_shape_x), enforce_detection = enforce_detection)
img2 = functions.detectFace(img2_path, input_shape, enforce_detection = enforce_detection) img2 = functions.detectFace(img2_path, (input_shape_y, input_shape_x), enforce_detection = enforce_detection)
#---------------------- #----------------------
#find embeddings #find embeddings
@ -499,9 +506,13 @@ def find(img_path, db_path
elif model_name == 'DeepFace': elif model_name == 'DeepFace':
print("Using FB DeepFace model backend", distance_metric,"distance.") print("Using FB DeepFace model backend", distance_metric,"distance.")
model = FbDeepFace.loadModel() model = FbDeepFace.loadModel()
elif model_name == 'DeepID':
print("Using DeepID model backend", distance_metric,"distance.")
model = DeepID.loadModel()
elif model_name == 'Ensemble': elif model_name == 'Ensemble':
print("Ensemble learning enabled") print("Ensemble learning enabled")
#TODO: include DeepID in ensemble method
import lightgbm as lgb #lightgbm==2.3.1 import lightgbm as lgb #lightgbm==2.3.1
@ -585,7 +596,9 @@ def find(img_path, db_path
if model_name != 'Ensemble': if model_name != 'Ensemble':
input_shape = model.layers[0].input_shape[1:3] input_shape = model.layers[0].input_shape[1:3]
img = functions.detectFace(employee, input_shape, enforce_detection = enforce_detection) input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
img = functions.detectFace(employee, (input_shape_y, input_shape_x), enforce_detection = enforce_detection)
representation = model.predict(img)[0,:] representation = model.predict(img)[0,:]
instance = [] instance = []
@ -600,7 +613,9 @@ def find(img_path, db_path
for j in model_names: for j in model_names:
model = models[j] model = models[j]
input_shape = model.layers[0].input_shape[1:3] input_shape = model.layers[0].input_shape[1:3]
img = functions.detectFace(employee, input_shape, enforce_detection = enforce_detection) input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
img = functions.detectFace(employee, (input_shape_y, input_shape_x), enforce_detection = enforce_detection)
representation = model.predict(img)[0,:] representation = model.predict(img)[0,:]
instance.append(representation) instance.append(representation)
@ -705,7 +720,9 @@ def find(img_path, db_path
if model_name != 'Ensemble': if model_name != 'Ensemble':
input_shape = model.layers[0].input_shape[1:3] input_shape = model.layers[0].input_shape[1:3]
img = functions.detectFace(img_path, input_shape, enforce_detection = enforce_detection) input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
img = functions.detectFace(img_path, (input_shape_y, input_shape_x), enforce_detection = enforce_detection)
target_representation = model.predict(img)[0,:] target_representation = model.predict(img)[0,:]
distances = [] distances = []

View File

@ -0,0 +1,52 @@
import os
from pathlib import Path
import gdown
import keras
from keras.models import Model
from keras.layers import Conv2D, Activation, Input, Add, MaxPooling2D, Flatten, Dense, Dropout
import zipfile
#-------------------------------------
def loadModel():
myInput = Input(shape=(55, 47, 3))
x = Conv2D(20, (4, 4), name='Conv1', activation='relu', input_shape=(55, 47, 3))(myInput)
x = MaxPooling2D(pool_size=2, strides=2, name='Pool1')(x)
x = Dropout(rate=1, name='D1')(x)
x = Conv2D(40, (3, 3), name='Conv2', activation='relu')(x)
x = MaxPooling2D(pool_size=2, strides=2, name='Pool2')(x)
x = Dropout(rate=1, name='D2')(x)
x = Conv2D(60, (3, 3), name='Conv3', activation='relu')(x)
x = MaxPooling2D(pool_size=2, strides=2, name='Pool3')(x)
x = Dropout(rate=1, name='D3')(x)
x1 = Flatten()(x)
fc11 = Dense(160, name = 'fc11')(x1)
x2 = Conv2D(80, (2, 2), name='Conv4', activation='relu')(x)
x2 = Flatten()(x2)
fc12 = Dense(160, name = 'fc12')(x2)
y = Add()([fc11, fc12])
y = Activation('relu', name = 'deepid')(y)
model = Model(inputs=[myInput], outputs=y)
#---------------------------------
home = str(Path.home())
if os.path.isfile(home+'/.deepface/weights/deepid_keras_weights.h5') != True:
print("deepid_keras_weights.h5 will be downloaded...")
url = 'https://drive.google.com/uc?id=1uRLtBCTQQAvHJ_KVrdbRJiCKxU8m5q2J'
output = home+'/.deepface/weights/deepid_keras_weights.h5'
gdown.download(url, output, quiet=False)
model.load_weights(home+'/.deepface/weights/deepid_keras_weights.h5')
return model

View File

@ -129,6 +129,14 @@ def findThreshold(model_name, distance_metric):
elif distance_metric == 'euclidean_l2': elif distance_metric == 'euclidean_l2':
threshold = 0.64 threshold = 0.64
elif model_name == 'DeepID':
if distance_metric == 'cosine':
threshold = 0.015
elif distance_metric == 'euclidean':
threshold = 45
elif distance_metric == 'euclidean_l2':
threshold = 0.17
return threshold return threshold
def get_opencv_path(): def get_opencv_path():

View File

@ -105,7 +105,8 @@ dataset = [
['dataset/img6.jpg', 'dataset/img9.jpg', False], ['dataset/img6.jpg', 'dataset/img9.jpg', False],
] ]
models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace'] models = ['DeepID']
#models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID']
metrics = ['cosine', 'euclidean', 'euclidean_l2'] metrics = ['cosine', 'euclidean', 'euclidean_l2']
passed_tests = 0; test_cases = 0 passed_tests = 0; test_cases = 0
@ -178,3 +179,4 @@ facial_attribute_models["gender"] = gender_model
facial_attribute_models["race"] = race_model facial_attribute_models["race"] = race_model
resp_obj = DeepFace.analyze("dataset/img1.jpg", models=facial_attribute_models) resp_obj = DeepFace.analyze("dataset/img1.jpg", models=facial_attribute_models)