mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
arcface added
This commit is contained in:
parent
3e93c47d09
commit
5cfdb3b179
42
api/api.py
42
api/api.py
@ -10,13 +10,6 @@ import tensorflow as tf
|
||||
tf_version = int(tf.__version__.split(".")[0])
|
||||
|
||||
from deepface import DeepFace
|
||||
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
|
||||
from deepface.basemodels.DlibResNet import DlibResNet
|
||||
from deepface.extendedmodels import Age, Gender, Race, Emotion
|
||||
|
||||
#import DeepFace
|
||||
#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
|
||||
#from extendedmodels import Age, Gender, Race, Emotion
|
||||
|
||||
#------------------------------
|
||||
|
||||
@ -28,28 +21,29 @@ tic = time.time()
|
||||
|
||||
print("Loading Face Recognition Models...")
|
||||
|
||||
pbar = tqdm(range(0,6), desc='Loading Face Recognition Models...')
|
||||
pbar = tqdm(range(0, 6), desc='Loading Face Recognition Models...')
|
||||
|
||||
for index in pbar:
|
||||
|
||||
if index == 0:
|
||||
pbar.set_description("Loading VGG-Face")
|
||||
vggface_model = VGGFace.loadModel()
|
||||
vggface_model = DeepFace.build_model("VGG-Face")
|
||||
elif index == 1:
|
||||
pbar.set_description("Loading OpenFace")
|
||||
openface_model = OpenFace.loadModel()
|
||||
openface_model = DeepFace.build_model("OpenFace")
|
||||
elif index == 2:
|
||||
pbar.set_description("Loading Google FaceNet")
|
||||
facenet_model = Facenet.loadModel()
|
||||
facenet_model = DeepFace.build_model("Facenet")
|
||||
elif index == 3:
|
||||
pbar.set_description("Loading Facebook DeepFace")
|
||||
deepface_model = FbDeepFace.loadModel()
|
||||
deepface_model = DeepFace.build_model("DeepFace")
|
||||
elif index == 4:
|
||||
pbar.set_description("Loading DeepID DeepFace")
|
||||
deepid_model = DeepID.loadModel()
|
||||
deepid_model = DeepFace.build_model("DeepID")
|
||||
elif index == 5:
|
||||
pbar.set_description("Loading Dlib ResNet DeepFace")
|
||||
dlib_model = DlibResNet()
|
||||
|
||||
pbar.set_description("Loading ArcFace DeepFace")
|
||||
arcface_model = DeepFace.build_model("ArcFace")
|
||||
|
||||
toc = time.time()
|
||||
|
||||
print("Face recognition models are built in ", toc-tic," seconds")
|
||||
@ -65,16 +59,16 @@ pbar = tqdm(range(0,4), desc='Loading Facial Attribute Analysis Models...')
|
||||
for index in pbar:
|
||||
if index == 0:
|
||||
pbar.set_description("Loading emotion analysis model")
|
||||
emotion_model = Emotion.loadModel()
|
||||
emotion_model = DeepFace.build_model('Emotion')
|
||||
elif index == 1:
|
||||
pbar.set_description("Loading age prediction model")
|
||||
age_model = Age.loadModel()
|
||||
age_model = DeepFace.build_model('Age')
|
||||
elif index == 2:
|
||||
pbar.set_description("Loading gender prediction model")
|
||||
gender_model = Gender.loadModel()
|
||||
gender_model = DeepFace.build_model('Gender')
|
||||
elif index == 3:
|
||||
pbar.set_description("Loading race prediction model")
|
||||
race_model = Race.loadModel()
|
||||
race_model = DeepFace.build_model('Race')
|
||||
|
||||
toc = time.time()
|
||||
|
||||
@ -231,19 +225,17 @@ def verifyWrapper(req, trx_id = 0):
|
||||
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepface_model)
|
||||
elif model_name == "DeepID":
|
||||
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepid_model)
|
||||
elif model_name == "Dlib":
|
||||
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = dlib_model)
|
||||
elif model_name == "ArcFace":
|
||||
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = arcface_model)
|
||||
elif model_name == "Ensemble":
|
||||
models = {}
|
||||
models["VGG-Face"] = vggface_model
|
||||
models["Facenet"] = facenet_model
|
||||
models["OpenFace"] = openface_model
|
||||
models["DeepFace"] = deepface_model
|
||||
|
||||
resp_obj = DeepFace.verify(instances, model_name = model_name, model = models)
|
||||
|
||||
else:
|
||||
resp_obj = jsonify({'success': False, 'error': 'You must pass a valid model name. Available models are VGG-Face, Facenet, OpenFace, DeepFace but you passed %s' % (model_name)}), 205
|
||||
resp_obj = jsonify({'success': False, 'error': 'You must pass a valid model name. You passed %s' % (model_name)}), 205
|
||||
|
||||
return resp_obj
|
||||
|
||||
|
@ -9,7 +9,7 @@ import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import pickle
|
||||
|
||||
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID, DlibWrapper, Boosting
|
||||
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID, DlibWrapper, ArcFace, Boosting
|
||||
from deepface.extendedmodels import Age, Gender, Race, Emotion
|
||||
from deepface.commons import functions, realtime, distance as dst
|
||||
|
||||
@ -33,6 +33,7 @@ def build_model(model_name):
|
||||
'DeepFace': FbDeepFace.loadModel,
|
||||
'DeepID': DeepID.loadModel,
|
||||
'Dlib': DlibWrapper.loadModel,
|
||||
'ArcFace': ArcFace.loadModel,
|
||||
'Emotion': Emotion.loadModel,
|
||||
'Age': Age.loadModel,
|
||||
'Gender': Gender.loadModel,
|
||||
@ -62,7 +63,7 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric =
|
||||
['img2.jpg', 'img3.jpg']
|
||||
]
|
||||
|
||||
model_name (string): VGG-Face, Facenet, OpenFace, DeepFace, DeepID, Dlib or Ensemble
|
||||
model_name (string): VGG-Face, Facenet, OpenFace, DeepFace, DeepID, Dlib, ArcFace or Ensemble
|
||||
|
||||
distance_metric (string): cosine, euclidean, euclidean_l2
|
||||
|
||||
|
94
deepface/basemodels/ArcFace.py
Normal file
94
deepface/basemodels/ArcFace.py
Normal file
@ -0,0 +1,94 @@
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.lib.io import file_io
|
||||
import tensorflow
|
||||
from tensorflow import keras
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import gdown
|
||||
|
||||
def loadModel():
|
||||
base_model = ResNet34()
|
||||
inputs = base_model.inputs[0]
|
||||
arcface_model = base_model.outputs[0]
|
||||
arcface_model = keras.layers.BatchNormalization(momentum=0.9, epsilon=2e-5)(arcface_model)
|
||||
arcface_model = keras.layers.Dropout(0.4)(arcface_model)
|
||||
arcface_model = keras.layers.Flatten()(arcface_model)
|
||||
arcface_model = keras.layers.Dense(512, activation=None, use_bias=True, kernel_initializer="glorot_normal")(arcface_model)
|
||||
embedding = keras.layers.BatchNormalization(momentum=0.9, epsilon=2e-5, name="embedding", scale=True)(arcface_model)
|
||||
model = keras.models.Model(inputs, embedding, name=base_model.name)
|
||||
|
||||
#---------------------------------------
|
||||
#check the availability of pre-trained weights
|
||||
|
||||
home = str(Path.home())
|
||||
|
||||
url = "https://drive.google.com/uc?id=1LVB3CdVejpmGHM28BpqqkbZP5hDEcdZY"
|
||||
file_name = "arcface_weights.h5"
|
||||
output = home+'/.deepface/weights/'+file_name
|
||||
|
||||
if os.path.isfile(output) != True:
|
||||
|
||||
print(file_name," will be downloaded to ",output)
|
||||
gdown.download(url, output, quiet=False)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
try:
|
||||
model.load_weights(output)
|
||||
except:
|
||||
print("pre-trained weights could not be loaded.")
|
||||
print("You might try to download it from the url ", url," and copy to ",output," manually")
|
||||
|
||||
return model
|
||||
|
||||
def ResNet34():
|
||||
|
||||
img_input = tensorflow.keras.layers.Input(shape=(112, 112, 3))
|
||||
|
||||
x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name='conv1_pad')(img_input)
|
||||
x = tensorflow.keras.layers.Conv2D(64, 3, strides=1, use_bias=False, kernel_initializer='glorot_normal', name='conv1_conv')(x)
|
||||
x = tensorflow.keras.layers.BatchNormalization(axis=3, epsilon=2e-5, momentum=0.9, name='conv1_bn')(x)
|
||||
x = tensorflow.keras.layers.PReLU(shared_axes=[1, 2], name='conv1_prelu')(x)
|
||||
x = stack_fn(x)
|
||||
|
||||
model = training.Model(img_input, x, name='ResNet34')
|
||||
|
||||
return model
|
||||
|
||||
def block1(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):
|
||||
bn_axis = 3
|
||||
|
||||
if conv_shortcut:
|
||||
shortcut = tensorflow.keras.layers.Conv2D(filters, 1, strides=stride, use_bias=False, kernel_initializer='glorot_normal', name=name + '_0_conv')(x)
|
||||
shortcut = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_0_bn')(shortcut)
|
||||
else:
|
||||
shortcut = x
|
||||
|
||||
x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_1_bn')(x)
|
||||
x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name=name + '_1_pad')(x)
|
||||
x = tensorflow.keras.layers.Conv2D(filters, 3, strides=1, kernel_initializer='glorot_normal', use_bias=False, name=name + '_1_conv')(x)
|
||||
x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_2_bn')(x)
|
||||
x = tensorflow.keras.layers.PReLU(shared_axes=[1, 2], name=name + '_1_prelu')(x)
|
||||
|
||||
x = tensorflow.keras.layers.ZeroPadding2D(padding=1, name=name + '_2_pad')(x)
|
||||
x = tensorflow.keras.layers.Conv2D(filters, kernel_size, strides=stride, kernel_initializer='glorot_normal', use_bias=False, name=name + '_2_conv')(x)
|
||||
x = tensorflow.keras.layers.BatchNormalization(axis=bn_axis, epsilon=2e-5, momentum=0.9, name=name + '_3_bn')(x)
|
||||
|
||||
x = tensorflow.keras.layers.Add(name=name + '_add')([shortcut, x])
|
||||
return x
|
||||
|
||||
def stack1(x, filters, blocks, stride1=2, name=None):
|
||||
x = block1(x, filters, stride=stride1, name=name + '_block1')
|
||||
for i in range(2, blocks + 1):
|
||||
x = block1(x, filters, conv_shortcut=False, name=name + '_block' + str(i))
|
||||
return x
|
||||
|
||||
def stack_fn(x):
|
||||
x = stack1(x, 64, 3, name='conv2')
|
||||
x = stack1(x, 128, 4, name='conv3')
|
||||
x = stack1(x, 256, 6, name='conv4')
|
||||
return stack1(x, 512, 3, name='conv5')
|
@ -70,16 +70,20 @@ def loadModel(url = 'https://drive.google.com/uc?id=1CPSeum3HpopfomUEK1gybeuIVoe
|
||||
#-----------------------------------
|
||||
|
||||
home = str(Path.home())
|
||||
output = home+'/.deepface/weights/vgg_face_weights.h5'
|
||||
|
||||
if os.path.isfile(home+'/.deepface/weights/vgg_face_weights.h5') != True:
|
||||
if os.path.isfile(output) != True:
|
||||
print("vgg_face_weights.h5 will be downloaded...")
|
||||
|
||||
output = home+'/.deepface/weights/vgg_face_weights.h5'
|
||||
gdown.download(url, output, quiet=False)
|
||||
|
||||
#-----------------------------------
|
||||
|
||||
model.load_weights(home+'/.deepface/weights/vgg_face_weights.h5')
|
||||
try:
|
||||
model.load_weights(output)
|
||||
except Exception as err:
|
||||
print(str(err))
|
||||
print("Pre-trained weight could not be loaded.")
|
||||
print("You might try to download the pre-trained weights from the url ", url, " and copy it to the ", output)
|
||||
|
||||
#-----------------------------------
|
||||
|
||||
|
@ -25,7 +25,8 @@ def findThreshold(model_name, distance_metric):
|
||||
'Facenet': {'cosine': 0.40, 'euclidean': 10, 'euclidean_l2': 0.80},
|
||||
'DeepFace': {'cosine': 0.23, 'euclidean': 64, 'euclidean_l2': 0.64},
|
||||
'DeepID': {'cosine': 0.015, 'euclidean': 45, 'euclidean_l2': 0.17},
|
||||
'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.6}
|
||||
'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.6},
|
||||
'ArcFace': {'cosine': 0.6871912959056619, 'euclidean': 4.1591468986978075, 'euclidean_l2': 1.1315718048269017}
|
||||
}
|
||||
|
||||
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
|
||||
|
@ -148,7 +148,7 @@ dataset = [
|
||||
['dataset/img6.jpg', 'dataset/img9.jpg', False],
|
||||
]
|
||||
|
||||
models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID', 'Dlib']
|
||||
models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID', 'Dlib', 'ArcFace']
|
||||
metrics = ['cosine', 'euclidean', 'euclidean_l2']
|
||||
|
||||
passed_tests = 0; test_cases = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user