dlib resnet added

This commit is contained in:
Şefik Serangil 2020-08-16 21:24:12 +03:00
parent 0c12758695
commit 3471874b8f
8 changed files with 136 additions and 32 deletions

View File

@ -46,10 +46,10 @@ df = DeepFace.find(img_path = "img1.jpg", db_path = "C:/workspace/my_db")
**Face recognition models**
Deepface is a hybrid face recognition package. It currently wraps the **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/) and [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/). The default configuration verifies faces with **VGG-Face** model. You can set the base model while verification as illustared below.
Deepface is a hybrid face recognition package. It currently wraps the **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/) and [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/). The default configuration verifies faces with **VGG-Face** model. You can set the base model while verification as illustared below.
```python
models = ["VGG-Face", "Facenet", "OpenFace", "DeepFace", "DeepID"]
models = ["VGG-Face", "Facenet", "OpenFace", "DeepFace", "DeepID", "Dlib"]
for model in models:
result = DeepFace.verify("img1.jpg", "img2.jpg", model_name = model)
```

View File

@ -10,6 +10,7 @@ import tensorflow as tf
from deepface import DeepFace
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
from deepface.basemodels.DlibResNet import DlibResNet
from deepface.extendedmodels import Age, Gender, Race, Emotion
#import DeepFace
@ -26,7 +27,7 @@ tic = time.time()
print("Loading Face Recognition Models...")
pbar = tqdm(range(0,5), desc='Loading Face Recognition Models...')
pbar = tqdm(range(0,6), desc='Loading Face Recognition Models...')
for index in pbar:
if index == 0:
@ -44,6 +45,9 @@ for index in pbar:
elif index == 4:
pbar.set_description("Loading DeepID DeepFace")
deepid_model = DeepID.loadModel()
elif index == 5:
pbar.set_description("Loading Dlib ResNet DeepFace")
dlib_model = DlibResNet()
toc = time.time()
@ -200,6 +204,8 @@ def verify():
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepface_model)
elif model_name == "DeepID":
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = deepid_model)
elif model_name == "Dlib":
resp_obj = DeepFace.verify(instances, model_name = model_name, distance_metric = distance_metric, model = dlib_model)
elif model_name == "Ensemble":
models = {}
models["VGG-Face"] = vggface_model

View File

@ -18,6 +18,7 @@ import pickle
from deepface import DeepFace
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
from deepface.basemodels.DlibResNet import DlibResNet
from deepface.extendedmodels import Age, Gender, Race, Emotion
from deepface.commons import functions, realtime, distance as dst
@ -219,6 +220,10 @@ def verify(img1_path, img2_path=''
elif model_name == 'DeepID':
print("Using DeepID2 model backend", distance_metric,"distance.")
model = DeepID.loadModel()
elif model_name == 'Dlib':
print("Using Dlib ResNet model backend", distance_metric,"distance.")
model = DlibResNet()
else:
raise ValueError("Invalid model_name passed - ", model_name)
@ -227,15 +232,19 @@ def verify(img1_path, img2_path=''
#------------------------------
#face recognition models have different size of inputs
#input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
#my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
if model_name == 'Dlib': #this is not a regular keras model
input_shape = (150, 150, 3)
input_shape = model.layers[0].input_shape
if type(input_shape) == list:
input_shape = input_shape[0][1:3]
else:
input_shape = input_shape[1:3]
else: #keras based models
input_shape = model.layers[0].input_shape
if type(input_shape) == list:
input_shape = input_shape[0][1:3]
else:
input_shape = input_shape[1:3]
input_shape_x = input_shape[0]
input_shape_y = input_shape[1]
@ -536,8 +545,10 @@ def find(img_path, db_path
elif model_name == 'DeepID':
print("Using DeepID model backend", distance_metric,"distance.")
model = DeepID.loadModel()
elif model_name == 'Dlib':
print("Using Dlib ResNet model backend", distance_metric,"distance.")
model = DlibResNet()
elif model_name == 'Ensemble':
print("Ensemble learning enabled")
#TODO: include DeepID in ensemble method
@ -622,15 +633,20 @@ def find(img_path, db_path
if model_name != 'Ensemble':
#input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
input_shape = model.layers[0].input_shape
if type(input_shape) == list:
input_shape = input_shape[0][1:3]
if model_name == 'Dlib': #non-keras model
input_shape = (150, 150, 3)
else:
input_shape = input_shape[1:3]
#input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
input_shape = model.layers[0].input_shape
if type(input_shape) == list:
input_shape = input_shape[0][1:3]
else:
input_shape = input_shape[1:3]
#---------------------
input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
img = functions.detectFace(employee, (input_shape_y, input_shape_x), enforce_detection = enforce_detection)
@ -779,15 +795,20 @@ def find(img_path, db_path
#----------------------------------
if model_name != 'Ensemble':
#input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
input_shape = model.layers[0].input_shape
if type(input_shape) == list:
input_shape = input_shape[0][1:3]
if model_name == 'Dlib': #non-keras model
input_shape = (150, 150, 3)
else:
input_shape = input_shape[1:3]
#input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
input_shape = model.layers[0].input_shape
if type(input_shape) == list:
input_shape = input_shape[0][1:3]
else:
input_shape = input_shape[1:3]
#------------------------
input_shape_x = input_shape[0]; input_shape_y = input_shape[1]

View File

@ -0,0 +1,62 @@
import dlib
import os
import zipfile
import bz2
import gdown
import numpy as np
from pathlib import Path
class DlibResNet:
def __init__(self):
home = str(Path.home())
weight_file = home+'/.deepface/weights/dlib_face_recognition_resnet_model_v1.dat'
#---------------------
#download pre-trained model if it does not exist
if os.path.isfile(weight_file) != True:
print("dlib_face_recognition_resnet_model_v1.dat is going to be downloaded")
url = "http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2"
output = home+'/.deepface/weights/'+url.split("/")[-1]
gdown.download(url, output, quiet=False)
zipfile = bz2.BZ2File(output)
data = zipfile.read()
newfilepath = output[:-4] #discard .bz2 extension
open(newfilepath, 'wb').write(data)
#---------------------
model = dlib.face_recognition_model_v1(weight_file)
self.__model = model
#---------------------
return None #classes must return None
def predict(self, img_aligned):
#functions.detectFace returns 4 dimensional images
if len(img_aligned.shape) == 4:
img_aligned = img_aligned[0]
#functions.detectFace returns bgr images
img_aligned = img_aligned[:,:,::-1] #bgr to rgb
#deepface.detectFace returns an array in scale of [0, 1] but dlib expects in scale of [0, 255]
if img_aligned.max() <= 1:
img_aligned = img_aligned * 255
img_aligned = img_aligned.astype(np.uint8)
model = self.__model
img_representation = model.compute_face_descriptor(img_aligned)
img_representation = np.array(img_representation)
img_representation = np.expand_dims(img_representation, axis = 0)
return img_representation

View File

@ -137,6 +137,14 @@ def findThreshold(model_name, distance_metric):
elif distance_metric == 'euclidean_l2':
threshold = 0.17
elif model_name == 'Dlib':
if distance_metric == 'cosine':
threshold = 0.07
elif distance_metric == 'euclidean':
threshold = 0.60
elif distance_metric == 'euclidean_l2':
threshold = 0.60
return threshold
def get_opencv_path():

View File

@ -10,6 +10,7 @@ import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
from deepface.basemodels.DlibResNet import DlibResNet
from deepface.extendedmodels import Age, Gender, Race, Emotion
from deepface.commons import functions, realtime, distance as dst
@ -58,6 +59,11 @@ def analysis(db_path, model_name, distance_metric, enable_face_analysis = True):
model = DeepID.loadModel()
input_shape = (55, 47)
elif model_name == 'Dlib':
print("Using Dlib model backend", distance_metric,"distance.")
model = DlibResNet()
input_shape = (150, 150)
else:
raise ValueError("Invalid model_name passed - ", model_name)
#------------------------

View File

@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
setuptools.setup(
name="deepface",
version="0.0.33",
version="0.0.34",
author="Sefik Ilkin Serengil",
author_email="serengil@gmail.com",
description="A Lightweight Face Recognition and Facial Attribute Analysis Framework (Age, Gender, Emotion, Race) for Python",

View File

@ -11,7 +11,9 @@ print("-----------------------------------------")
print("Large scale face recognition")
df = DeepFace.find(img_path = "dataset/img1.jpg", db_path = "dataset")
df = DeepFace.find(img_path = "dataset/img1.jpg", db_path = "dataset"
#, model_name = 'Dlib'
)
print(df.head())
print("-----------------------------------------")
@ -105,7 +107,7 @@ dataset = [
['dataset/img6.jpg', 'dataset/img9.jpg', False],
]
models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID']
models = ['VGG-Face', 'Facenet', 'OpenFace', 'DeepFace', 'DeepID', 'Dlib']
metrics = ['cosine', 'euclidean', 'euclidean_l2']
passed_tests = 0; test_cases = 0
@ -134,7 +136,7 @@ for model in models:
test_cases = test_cases + 1
print(img1, " and ", img2," are ", classified_label, " as same person based on ", model," model and ",metric," distance metric. Distance: ",distance,", Required Threshold: ", required_threshold," (",test_result_label,")")
print(img1.split("/")[-1], "and", img2.split("/")[-1],"are", classified_label, "as same person based on", model,"model and",metric,"distance. Distance:",distance,", Threshold:", required_threshold,"(",test_result_label,")")
print("--------------------------")
@ -178,4 +180,3 @@ facial_attribute_models["gender"] = gender_model
facial_attribute_models["race"] = race_model
resp_obj = DeepFace.analyze("dataset/img1.jpg", models=facial_attribute_models)