arcface added

This commit is contained in:
serengil 2020-12-14 15:04:41 +03:00
parent 3e93c47d09
commit 5cfdb3b179
6 changed files with 125 additions and 33 deletions

View File

@ -10,13 +10,6 @@ import tensorflow as tf
tf_version = int(tf.__version__.split(".")[0])
from deepface import DeepFace
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
from deepface.basemodels.DlibResNet import DlibResNet
from deepface.extendedmodels import Age, Gender, Race, Emotion
#import DeepFace
#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
#from extendedmodels import Age, Gender, Race, Emotion
#------------------------------
@ -28,28 +21,29 @@ tic = time.time()
print("Loading Face Recognition Models...")
pbar = tqdm(range(0,6), desc='Loading Face Recognition Models...')
pbar = tqdm(range(0, 6), desc='Loading Face Recognition Models...')
for index in pbar:
if index == 0:
pbar.set_description("Loading VGG-Face")
vggface_model = VGGFace.loadModel()
vggface_model = DeepFace.build_model("VGG-Face")
elif index == 1:
pbar.set_description("Loading OpenFace")
openface_model = OpenFace.loadModel()
openface_model = DeepFace.build_model("OpenFace")
elif index == 2:
pbar.set_description("Loading Google FaceNet")
facenet_model = Facenet.loadModel()
facenet_model = DeepFace.build_model("Facenet")
elif index == 3:
pbar.set_description("Loading Facebook DeepFace")
deepface_model = FbDeepFace.loadModel()
deepface_model = DeepFace.build_model("DeepFace")
elif index == 4:
pbar.set_description("Loading DeepID DeepFace")
deepid_model = DeepID.loadModel()
deepid_model = DeepFace.build_model("DeepID")
elif index == 5:
pbar.set_description("Loading Dlib ResNet DeepFace")
dlib_model = DlibResNet()
pbar.set_description("Loading ArcFace DeepFace")
arcface_model = DeepFace.build_model("ArcFace")
toc = time.time()
print("Face recognition models are built in ", toc-tic," seconds")
@ -65,16 +59,16 @@ pbar = tqdm(range(0,4), desc='Loading Facial Attribute Analysis Models...')
for index in pbar:
if index == 0:
pbar.set_description("Loading emotion analysis model")
emotion_model = Emotion.loadModel()
emotion_model = DeepFace.build_model('Emotion')
elif index == 1:
pbar.set_description("Loading age prediction model")
age_model = Age.loadModel()
age_model = DeepFace.build_model('Age')
elif index == 2:
pbar.set_description("Loading gender prediction model")
gender_model = Gender.loadModel()
gender_model = DeepFace.build_model('Gender')
elif index == 3:
pbar.set_description("Loading race prediction model")
race_model = Race.loadModel()
race_model = DeepFace.build_model('Race')
toc = time.time()
@ -231,19 +225,17 @@ def verifyWrapper(req, trx_id = 0):
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepface_model)
elif model_name == "DeepID":
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepid_model)
elif model_name == "Dlib":
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = dlib_model)
elif model_name == "ArcFace":
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = arcface_model)
elif model_name == "Ensemble":
models = {}
models["VGG-Face"] = vggface_model
models["Facenet"] = facenet_model
models["OpenFace"] = openface_model
models["DeepFace"] = deepface_model
resp_obj = DeepFace.verify(instances, model_name = model_name, model = models)
else:
resp_obj = jsonify({'success': False, 'error': 'You must pass a valid model name. Available models are VGG-Face, Facenet, OpenFace, DeepFace but you passed %s' % (model_name)}), 205
resp_obj = jsonify({'success': False, 'error': 'You must pass a valid model name. You passed %s' % (model_name)}), 205
return resp_obj

View File

@ -9,7 +9,7 @@ import pandas as pd
from tqdm import tqdm
import pickle
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID, DlibWrapper, Boosting
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID, DlibWrapper, ArcFace, Boosting
from deepface.extendedmodels import Age, Gender, Race, Emotion
from deepface.commons import functions, realtime, distance as dst
@ -33,6 +33,7 @@ def build_model(model_name):
'DeepFace': FbDeepFace.loadModel,
'DeepID': DeepID.loadModel,
'Dlib': DlibWrapper.loadModel,
'ArcFace': ArcFace.loadModel,
'Emotion': Emotion.loadModel,
'Age': Age.loadModel,
'Gender': Gender.loadModel,
@ -62,7 +63,7 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric =
['img2.jpg', 'img3.jpg']
]
model_name (string): VGG-Face, Facenet, OpenFace, DeepFace, DeepID, Dlib or Ensemble
model_name (string): VGG-Face, Facenet, OpenFace, DeepFace, DeepID, Dlib, ArcFace or Ensemble
distance_metric (string): cosine, euclidean, euclidean_l2

View File

@ -0,0 +1,94 @@
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
import tensorflow
from tensorflow import keras
import os
from pathlib import Path
import gdown
def loadModel():
base_model = ResNet34()
inputs = base_model.inputs[0]
arcface_model = base_model.outputs[0]
arcface_model = keras.layers.BatchNormalization(momentum=0.9, epsilon=2e-5)(arcface_model)
arcface_model = keras.layers.Dropout(0.4)(arcface_model)
arcface_model = keras.layers.Flatten()(arcface_model)
arcface_model = keras.layers.Dense(512, activation=None, use_bias=True, kernel_initializer="glorot_normal")(arcface_model)
embedding = keras.layers.BatchNormalization(momentum=0.9, epsilon=2e-5, name="embedding", scale=True)(arcface_model)
model = keras.models.Model(inputs, embedding, name=base_model.name)
#---------------------------------------
#check the availability of pre-trained weights
home = str(Path.home())
url = "https://drive.google.com/uc?id=1LVB3CdVejpmGHM28BpqqkbZP5hDEcdZY"
file_name = "arcface_weights.h5"
output = home+'/.deepface/weights/'+file_name
if os.path.isfile(output) != True:
print(file_name," will be downloaded to ",output)
gdown.download(url, output, quiet=False)
#---------------------------------------
try:
model.load_weights(output)
except:
print("pre-trained weights could not be loaded.")
print("You might try to download it from the url ", url," and copy to ",output," manually")
return model
def ResNet34():
img_input = tensorflow.keras.layers.Input(shape=(112, 112, 3))
x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name='conv1_pad')(img_input)
x = tensorflow.keras.layers.Conv2D(64, 3, strides=1, use_bias=False, kernel_initializer='glorot_normal', name='conv1_conv')(x)
x = tensorflow.keras.layers.BatchNormalization(axis=3, epsilon=2e-5, momentum=0.9, name='conv1_bn')(x)
x = tensorflow.keras.layers.PReLU(shared_axes=[1, 2], name='conv1_prelu')(x)
x = stack_fn(x)
model = training.Model(img_input, x, name='ResNet34')
return model
def block1(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):
bn_axis = 3
if conv_shortcut:
shortcut = tensorflow.keras.layers.Conv2D(filters, 1, strides=stride, use_bias=False, kernel_initializer='glorot_normal', name=name + '_0_conv')(x)
shortcut = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_0_bn')(shortcut)
else:
shortcut = x
x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_1_bn')(x)
x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name=name + '_1_pad')(x)
x = tensorflow.keras.layers.Conv2D(filters, 3, strides=1, kernel_initializer='glorot_normal', use_bias=False, name=name + '_1_conv')(x)
x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_2_bn')(x)
x = tensorflow.keras.layers.PReLU(shared_axes=[1, 2], name=name + '_1_prelu')(x)
x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name=name + '_2_pad')(x)
x = tensorflow.keras.layers.Conv2D(filters, kernel_size, strides=stride, kernel_initializer='glorot_normal', use_bias=False, name=name + '_2_conv')(x)
x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_3_bn')(x)
x = tensorflow.keras.layers.Add(name=name + '_add')([shortcut, x])
return x
def stack1(x, filters, blocks, stride1=2, name=None):
x = block1(x, filters, stride=stride1, name=name + '_block1')
for i in range(2, blocks + 1):
x = block1(x, filters, conv_shortcut=False, name=name + '_block' + str(i))
return x
def stack_fn(x):
x = stack1(x, 64, 3, name='conv2')
x = stack1(x, 128, 4, name='conv3')
x = stack1(x, 256, 6, name='conv4')
return stack1(x, 512, 3, name='conv5')

View File

@ -70,16 +70,20 @@ def loadModel(url = 'https://drive.google.com/uc?id=1CPSeum3HpopfomUEK1gybeuIVoe
#-----------------------------------
home = str(Path.home())
output = home+'/.deepface/weights/vgg_face_weights.h5'
if os.path.isfile(home+'/.deepface/weights/vgg_face_weights.h5') != True:
if os.path.isfile(output) != True:
print("vgg_face_weights.h5 will be downloaded...")
output = home+'/.deepface/weights/vgg_face_weights.h5'
gdown.download(url, output, quiet=False)
#-----------------------------------
model.load_weights(home+'/.deepface/weights/vgg_face_weights.h5')
try:
model.load_weights(output)
except Exception as err:
print(str(err))
print("Pre-trained weight could not be loaded.")
print("You might try to download the pre-trained weights from the url ", url, " and copy it to the ", output)
#-----------------------------------

View File

@ -25,7 +25,8 @@ def findThreshold(model_name, distance_metric):
'Facenet': {'cosine': 0.40, 'euclidean': 10, 'euclidean_l2': 0.80},
'DeepFace': {'cosine': 0.23, 'euclidean': 64, 'euclidean_l2': 0.64},
'DeepID': {'cosine': 0.015, 'euclidean': 45, 'euclidean_l2': 0.17},
'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.6}
'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.6},
'ArcFace': {'cosine': 0.6871912959056619, 'euclidean': 4.1591468986978075, 'euclidean_l2': 1.1315718048269017}
}
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)

View File

@ -148,7 +148,7 @@ dataset = [
['dataset/img6.jpg', 'dataset/img9.jpg', False],
]
models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID', 'Dlib']
models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID', 'Dlib', 'ArcFace']
metrics = ['cosine', 'euclidean', 'euclidean_l2']
passed_tests = 0; test_cases = 0