diff --git a/api/api.py b/api/api.py index 7caf6bb..7488696 100644 --- a/api/api.py +++ b/api/api.py @@ -10,13 +10,6 @@ import tensorflow as tf tf_version = int(tf.__version__.split(".")[0]) from deepface import DeepFace -from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID -from deepface.basemodels.DlibResNet import DlibResNet -from deepface.extendedmodels import Age, Gender, Race, Emotion - -#import DeepFace -#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace -#from extendedmodels import Age, Gender, Race, Emotion #------------------------------ @@ -28,28 +21,29 @@ tic = time.time() print("Loading Face Recognition Models...") -pbar = tqdm(range(0,6), desc='Loading Face Recognition Models...') +pbar = tqdm(range(0, 6), desc='Loading Face Recognition Models...') for index in pbar: + if index == 0: pbar.set_description("Loading VGG-Face") - vggface_model = VGGFace.loadModel() + vggface_model = DeepFace.build_model("VGG-Face") elif index == 1: pbar.set_description("Loading OpenFace") - openface_model = OpenFace.loadModel() + openface_model = DeepFace.build_model("OpenFace") elif index == 2: pbar.set_description("Loading Google FaceNet") - facenet_model = Facenet.loadModel() + facenet_model = DeepFace.build_model("Facenet") elif index == 3: pbar.set_description("Loading Facebook DeepFace") - deepface_model = FbDeepFace.loadModel() + deepface_model = DeepFace.build_model("DeepFace") elif index == 4: pbar.set_description("Loading DeepID DeepFace") - deepid_model = DeepID.loadModel() + deepid_model = DeepFace.build_model("DeepID") elif index == 5: - pbar.set_description("Loading Dlib ResNet DeepFace") - dlib_model = DlibResNet() - + pbar.set_description("Loading ArcFace DeepFace") + arcface_model = DeepFace.build_model("ArcFace") + toc = time.time() print("Face recognition models are built in ", toc-tic," seconds") @@ -65,16 +59,16 @@ pbar = tqdm(range(0,4), desc='Loading Facial Attribute Analysis Models...') for index in pbar: if index == 0: pbar.set_description("Loading emotion analysis model") - emotion_model = Emotion.loadModel() + emotion_model = DeepFace.build_model('Emotion') elif index == 1: pbar.set_description("Loading age prediction model") - age_model = Age.loadModel() + age_model = DeepFace.build_model('Age') elif index == 2: pbar.set_description("Loading gender prediction model") - gender_model = Gender.loadModel() + gender_model = DeepFace.build_model('Gender') elif index == 3: pbar.set_description("Loading race prediction model") - race_model = Race.loadModel() + race_model = DeepFace.build_model('Race') toc = time.time() @@ -231,19 +225,17 @@ def verifyWrapper(req, trx_id = 0): resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepface_model) elif model_name == "DeepID": resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepid_model) - elif model_name == "Dlib": - resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = dlib_model) + elif model_name == "ArcFace": + resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = arcface_model) elif model_name == "Ensemble": models = {} models["VGG-Face"] = vggface_model models["Facenet"] = facenet_model models["OpenFace"] = openface_model models["DeepFace"] = deepface_model - resp_obj = DeepFace.verify(instances, model_name = model_name, model = models) - else: - resp_obj = jsonify({'success': False, 'error': 'You must pass a valid model name. Available models are VGG-Face, Facenet, OpenFace, DeepFace but you passed %s' % (model_name)}), 205 + resp_obj = jsonify({'success': False, 'error': 'You must pass a valid model name. You passed %s' % (model_name)}), 205 return resp_obj diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index e1e594c..476c69c 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -9,7 +9,7 @@ import pandas as pd from tqdm import tqdm import pickle -from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID, DlibWrapper, Boosting +from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID, DlibWrapper, ArcFace, Boosting from deepface.extendedmodels import Age, Gender, Race, Emotion from deepface.commons import functions, realtime, distance as dst @@ -33,6 +33,7 @@ def build_model(model_name): 'DeepFace': FbDeepFace.loadModel, 'DeepID': DeepID.loadModel, 'Dlib': DlibWrapper.loadModel, + 'ArcFace': ArcFace.loadModel, 'Emotion': Emotion.loadModel, 'Age': Age.loadModel, 'Gender': Gender.loadModel, @@ -62,7 +63,7 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = ['img2.jpg', 'img3.jpg'] ] - model_name (string): VGG-Face, Facenet, OpenFace, DeepFace, DeepID, Dlib or Ensemble + model_name (string): VGG-Face, Facenet, OpenFace, DeepFace, DeepID, Dlib, ArcFace or Ensemble distance_metric (string): cosine, euclidean, euclidean_l2 diff --git a/deepface/basemodels/ArcFace.py b/deepface/basemodels/ArcFace.py new file mode 100644 index 0000000..2ba0b77 --- /dev/null +++ b/deepface/basemodels/ArcFace.py @@ -0,0 +1,94 @@ +from tensorflow.python.keras import backend +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.utils import data_utils +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io +import tensorflow +from tensorflow import keras + +import os +from pathlib import Path +import gdown + +def loadModel(): + base_model = ResNet34() + inputs = base_model.inputs[0] + arcface_model = base_model.outputs[0] + arcface_model = keras.layers.BatchNormalization(momentum=0.9, epsilon=2e-5)(arcface_model) + arcface_model = keras.layers.Dropout(0.4)(arcface_model) + arcface_model = keras.layers.Flatten()(arcface_model) + arcface_model = keras.layers.Dense(512, activation=None, use_bias=True, kernel_initializer="glorot_normal")(arcface_model) + embedding = keras.layers.BatchNormalization(momentum=0.9, epsilon=2e-5, name="embedding", scale=True)(arcface_model) + model = keras.models.Model(inputs, embedding, name=base_model.name) + + #--------------------------------------- + #check the availability of pre-trained weights + + home = str(Path.home()) + + url = "https://drive.google.com/uc?id=1LVB3CdVejpmGHM28BpqqkbZP5hDEcdZY" + file_name = "arcface_weights.h5" + output = home+'/.deepface/weights/'+file_name + + if os.path.isfile(output) != True: + + print(file_name," will be downloaded to ",output) + gdown.download(url, output, quiet=False) + + #--------------------------------------- + + try: + model.load_weights(output) + except: + print("pre-trained weights could not be loaded.") + print("You might try to download it from the url ", url," and copy to ",output," manually") + + return model + +def ResNet34(): + + img_input = tensorflow.keras.layers.Input(shape=(112, 112, 3)) + + x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name='conv1_pad')(img_input) + x = tensorflow.keras.layers.Conv2D(64, 3, strides=1, use_bias=False, kernel_initializer='glorot_normal', name='conv1_conv')(x) + x = tensorflow.keras.layers.BatchNormalization(axis=3, epsilon=2e-5, momentum=0.9, name='conv1_bn')(x) + x = tensorflow.keras.layers.PReLU(shared_axes=[1, 2], name='conv1_prelu')(x) + x = stack_fn(x) + + model = training.Model(img_input, x, name='ResNet34') + + return model + +def block1(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None): + bn_axis = 3 + + if conv_shortcut: + shortcut = tensorflow.keras.layers.Conv2D(filters, 1, strides=stride, use_bias=False, kernel_initializer='glorot_normal', name=name + '_0_conv')(x) + shortcut = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_0_bn')(shortcut) + else: + shortcut = x + + x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_1_bn')(x) + x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name=name + '_1_pad')(x) + x = tensorflow.keras.layers.Conv2D(filters, 3, strides=1, kernel_initializer='glorot_normal', use_bias=False, name=name + '_1_conv')(x) + x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_2_bn')(x) + x = tensorflow.keras.layers.PReLU(shared_axes=[1, 2], name=name + '_1_prelu')(x) + + x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name=name + '_2_pad')(x) + x = tensorflow.keras.layers.Conv2D(filters, kernel_size, strides=stride, kernel_initializer='glorot_normal', use_bias=False, name=name + '_2_conv')(x) + x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_3_bn')(x) + + x = tensorflow.keras.layers.Add(name=name + '_add')([shortcut, x]) + return x + +def stack1(x, filters, blocks, stride1=2, name=None): + x = block1(x, filters, stride=stride1, name=name + '_block1') + for i in range(2, blocks + 1): + x = block1(x, filters, conv_shortcut=False, name=name + '_block' + str(i)) + return x + +def stack_fn(x): + x = stack1(x, 64, 3, name='conv2') + x = stack1(x, 128, 4, name='conv3') + x = stack1(x, 256, 6, name='conv4') + return stack1(x, 512, 3, name='conv5') \ No newline at end of file diff --git a/deepface/basemodels/VGGFace.py b/deepface/basemodels/VGGFace.py index 4870593..1e13aa9 100644 --- a/deepface/basemodels/VGGFace.py +++ b/deepface/basemodels/VGGFace.py @@ -70,16 +70,20 @@ def loadModel(url = 'https://drive.google.com/uc?id=1CPSeum3HpopfomUEK1gybeuIVoe #----------------------------------- home = str(Path.home()) + output = home+'/.deepface/weights/vgg_face_weights.h5' - if os.path.isfile(home+'/.deepface/weights/vgg_face_weights.h5') != True: + if os.path.isfile(output) != True: print("vgg_face_weights.h5 will be downloaded...") - - output = home+'/.deepface/weights/vgg_face_weights.h5' gdown.download(url, output, quiet=False) #----------------------------------- - model.load_weights(home+'/.deepface/weights/vgg_face_weights.h5') + try: + model.load_weights(output) + except Exception as err: + print(str(err)) + print("Pre-trained weight could not be loaded.") + print("You might try to download the pre-trained weights from the url ", url, " and copy it to the ", output) #----------------------------------- diff --git a/deepface/commons/distance.py b/deepface/commons/distance.py index 63d674b..1b364a4 100644 --- a/deepface/commons/distance.py +++ b/deepface/commons/distance.py @@ -25,7 +25,8 @@ def findThreshold(model_name, distance_metric): 'Facenet': {'cosine': 0.40, 'euclidean': 10, 'euclidean_l2': 0.80}, 'DeepFace': {'cosine': 0.23, 'euclidean': 64, 'euclidean_l2': 0.64}, 'DeepID': {'cosine': 0.015, 'euclidean': 45, 'euclidean_l2': 0.17}, - 'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.6} + 'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.6}, + 'ArcFace': {'cosine': 0.6871912959056619, 'euclidean': 4.1591468986978075, 'euclidean_l2': 1.1315718048269017} } threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4) diff --git a/tests/unit_tests.py b/tests/unit_tests.py index 20e9075..3f1efa2 100644 --- a/tests/unit_tests.py +++ b/tests/unit_tests.py @@ -148,7 +148,7 @@ dataset = [ ['dataset/img6.jpg', 'dataset/img9.jpg', False], ] -models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID', 'Dlib'] +models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID', 'Dlib', 'ArcFace'] metrics = ['cosine', 'euclidean', 'euclidean_l2'] passed_tests = 0; test_cases = 0