Merge pull request #1111 from serengil/feat-task-1003-ghostfacenet-model

Feat task 1003 ghostfacenet model
This commit is contained in:
Sefik Ilkin Serengil 2024-03-16 10:00:08 +00:00 committed by GitHub
commit 167a8ca392
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 368 additions and 30 deletions

View File

@ -21,7 +21,7 @@
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/deepface-icon-labeled.png" width="200" height="240"></p>
Deepface is a lightweight [face recognition](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) and facial attribute analysis ([age](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [gender](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [emotion](https://sefiks.com/2018/01/01/facial-expression-recognition-with-keras/) and [race](https://sefiks.com/2019/11/11/race-and-ethnicity-prediction-in-keras/)) framework for python. It is a hybrid face recognition framework wrapping **state-of-the-art** models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/), [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/) and `SFace`.
Deepface is a lightweight [face recognition](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) and facial attribute analysis ([age](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [gender](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [emotion](https://sefiks.com/2018/01/01/facial-expression-recognition-with-keras/) and [race](https://sefiks.com/2019/11/11/race-and-ethnicity-prediction-in-keras/)) framework for python. It is a hybrid face recognition framework wrapping **state-of-the-art** models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/), [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`.
Experiments show that human beings have 97.53% accuracy on facial recognition tasks whereas those models already reached and passed that accuracy level.
@ -100,7 +100,7 @@ Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introdu
**Face recognition models** - [`Demo`](https://youtu.be/i_MOwvhbLdI)
Deepface is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/) and `SFace`. The default configuration uses VGG-Face model.
Deepface is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`. The default configuration uses VGG-Face model.
```python
models = [
@ -113,6 +113,7 @@ models = [
"ArcFace",
"Dlib",
"SFace",
"GhostFaceNet",
]
#face verification
@ -135,19 +136,22 @@ embedding_objs = DeepFace.represent(img_path = "img.jpg",
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/model-portfolio-v8.jpg" width="95%" height="95%"></p>
FaceNet, VGG-Face, ArcFace and Dlib are [overperforming](https://youtu.be/i_MOwvhbLdI) ones based on experiments. You can find out the scores of those models below on both [Labeled Faces in the Wild](https://sefiks.com/2020/08/27/labeled-faces-in-the-wild-for-face-recognition/) and YouTube Faces in the Wild data sets declared by its creators.
FaceNet, VGG-Face, ArcFace and Dlib are [overperforming](https://youtu.be/i_MOwvhbLdI) ones based on experiments. You can find out the scores of those models below on [Labeled Faces in the Wild](https://sefiks.com/2020/08/27/labeled-faces-in-the-wild-for-face-recognition/) set declared by its creators.
| Model | LFW Score | YTF Score |
| --- | --- | --- |
| Facenet512 | 99.65% | - |
| SFace | 99.60% | - |
| ArcFace | 99.41% | - |
| Dlib | 99.38 % | - |
| Facenet | 99.20% | - |
| VGG-Face | 98.78% | 97.40% |
| *Human-beings* | *97.53%* | - |
| OpenFace | 93.80% | - |
| DeepID | - | 97.05% |
| Model | Declared LFW Score |
| --- | --- |
| VGG-Face | 98.78% |
| Facenet | 99.20% |
| Facenet512 | 99.65% |
| OpenFace | 93.80% |
| DeepID | - |
| Dlib | 99.38 % |
| SFace | 99.60% |
| ArcFace | 99.41% |
| GhostFaceNet | 99.76 |
| *Human-beings* | *97.53%* |
Conducting experiments with those models within DeepFace may reveal disparities compared to the original studies, owing to the adoption of distinct detection or normalization techniques. Furthermore, some models have been released solely with their backbones, lacking pre-trained weights. Thus, we are utilizing their re-implementations instead of the original pre-trained weights.
**Similarity**
@ -374,6 +378,6 @@ Also, if you use deepface in your GitHub projects, please add `deepface` in the
DeepFace is licensed under the MIT License - see [`LICENSE`](https://github.com/serengil/deepface/blob/master/LICENSE) for more details.
DeepFace wraps some external face recognition models: [VGG-Face](http://www.robots.ox.ac.uk/~vgg/software/vgg_face/), [Facenet](https://github.com/davidsandberg/facenet/blob/master/LICENSE.md), [OpenFace](https://github.com/iwantooxxoox/Keras-OpenFace/blob/master/LICENSE), [DeepFace](https://github.com/swghosh/DeepFace), [DeepID](https://github.com/Ruoyiran/DeepID/blob/master/LICENSE.md), [ArcFace](https://github.com/leondgarse/Keras_insightface/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt), and [SFace](https://github.com/opencv/opencv_zoo/blob/master/models/face_recognition_sface/LICENSE). Besides, age, gender and race / ethnicity models were trained on the backbone of VGG-Face with transfer learning. Licence types will be inherited if you are going to use those models. Please check the license types of those models for production purposes.
DeepFace wraps some external face recognition models: [VGG-Face](http://www.robots.ox.ac.uk/~vgg/software/vgg_face/), [Facenet](https://github.com/davidsandberg/facenet/blob/master/LICENSE.md), [OpenFace](https://github.com/iwantooxxoox/Keras-OpenFace/blob/master/LICENSE), [DeepFace](https://github.com/swghosh/DeepFace), [DeepID](https://github.com/Ruoyiran/DeepID/blob/master/LICENSE.md), [ArcFace](https://github.com/leondgarse/Keras_insightface/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt), [SFace](https://github.com/opencv/opencv_zoo/blob/master/models/face_recognition_sface/LICENSE) and [`GhostFaceNet`](https://github.com/HamadYA/GhostFaceNets/blob/main/LICENSE). Besides, age, gender and race / ethnicity models were trained on the backbone of VGG-Face with transfer learning. Licence types will be inherited if you are going to use those models. Please check the license types of those models for production purposes.
DeepFace [logo](https://thenounproject.com/term/face-recognition/2965879/) is created by [Adrien Coquet](https://thenounproject.com/coquet_adrien/) and it is licensed under [Creative Commons: By Attribution 3.0 License](https://creativecommons.org/licenses/by/3.0/).

View File

@ -76,7 +76,7 @@ def verify(
or pre-calculated embeddings.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
@ -254,7 +254,7 @@ def find(
in the database will be considered in the decision-making process.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
@ -331,7 +331,8 @@ def represent(
include information for each detected face.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face.).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
(default is VGG-Face.).
enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images
@ -393,7 +394,7 @@ def stream(
in the database will be considered in the decision-making process.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).

View File

@ -0,0 +1,319 @@
# built-in dependencies
import os
from typing import List
# 3rd party dependencies
import gdown
import numpy as np
import tensorflow as tf
# project dependencies
from deepface.commons import package_utils, folder_utils
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
logger = Logger(module="basemodels.VGGFace")
tf_major = package_utils.get_tf_major_version()
if tf_major == 1:
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import (
Activation,
Add,
BatchNormalization,
Concatenate,
Conv2D,
DepthwiseConv2D,
GlobalAveragePooling2D,
Input,
Reshape,
Multiply,
ReLU,
PReLU,
)
else:
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Activation,
Add,
BatchNormalization,
Concatenate,
Conv2D,
DepthwiseConv2D,
GlobalAveragePooling2D,
Input,
Reshape,
Multiply,
ReLU,
PReLU,
)
# pylint: disable=line-too-long, too-few-public-methods, no-else-return, unsubscriptable-object, comparison-with-callable
PRETRAINED_WEIGHTS = "https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5"
class GhostFaceNetClient(FacialRecognition):
"""
GhostFaceNet model (GhostFaceNetV1 backbone)
Repo: https://github.com/HamadYA/GhostFaceNets
Pre-trained weights: https://github.com/HamadYA/GhostFaceNets/releases/tag/v1.2
GhostFaceNet_W1.3_S1_ArcFace.h5 ~ 16.5MB
Author declared that this backbone and pre-trained weights got 99.7667% accuracy on LFW
"""
def __init__(self):
self.model_name = "GhostFaceNet"
self.input_shape = (112, 112)
self.output_shape = 512
self.model = load_model()
def find_embeddings(self, img: np.ndarray) -> List[float]:
# model.predict causes memory issue when it is called in a for loop
# embedding = model.predict(img, verbose=0)[0].tolist()
return self.model(img, training=False).numpy()[0].tolist()
def load_model():
model = GhostFaceNetV1()
home = folder_utils.get_deepface_home()
output = home + "/.deepface/weights/ghostfacenet_v1.h5"
if os.path.isfile(output) is not True:
logger.info("Pre-trained weights is downloaded from {PRETRAINED_WEIGHTS} to {output}")
gdown.download(PRETRAINED_WEIGHTS, output, quiet=False)
logger.info(f"Pre-trained weights is just downloaded to {output}")
model.load_weights(output)
return model
def GhostFaceNetV1() -> Model:
"""
Build GhostFaceNetV1 model. Refactored from
github.com/HamadYA/GhostFaceNets/blob/main/backbones/ghost_model.py
Returns:
model (Model)
"""
inputs = Input(shape=(112, 112, 3))
out_channel = 20
nn = Conv2D(
out_channel,
(3, 3),
strides=1,
padding="same",
use_bias=False,
kernel_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(inputs)
nn = BatchNormalization(axis=-1)(nn)
nn = Activation("relu")(nn)
dwkernels = [3, 3, 3, 5, 5, 3, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5]
exps = [20, 64, 92, 92, 156, 312, 260, 240, 240, 624, 872, 872, 1248, 1248, 1248, 664]
outs = [20, 32, 32, 52, 52, 104, 104, 104, 104, 144, 144, 208, 208, 208, 208, 208]
strides_set = [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1]
reductions = [0, 0, 0, 24, 40, 0, 0, 0, 0, 156, 220, 220, 0, 312, 0, 168]
pre_out = out_channel
for dwk, stride, exp, out, reduction in zip(dwkernels, strides_set, exps, outs, reductions):
shortcut = not (out == pre_out and stride == 1)
nn = ghost_bottleneck(nn, dwk, stride, exp, out, reduction, shortcut)
pre_out = out
nn = Conv2D(
664,
(1, 1),
strides=(1, 1),
padding="valid",
use_bias=False,
kernel_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(nn)
nn = BatchNormalization(axis=-1)(nn)
nn = Activation("relu")(nn)
xx = Model(inputs=inputs, outputs=nn, name="GhostFaceNetV1")
# post modelling
inputs = xx.inputs[0]
nn = xx.outputs[0]
nn = keras.layers.DepthwiseConv2D(nn.shape[1], use_bias=False, name="GDC_dw")(nn)
nn = keras.layers.BatchNormalization(momentum=0.99, epsilon=0.001, name="GDC_batchnorm")(nn)
nn = keras.layers.Conv2D(
512, 1, use_bias=True, kernel_initializer="glorot_normal", name="GDC_conv"
)(nn)
nn = keras.layers.Flatten(name="GDC_flatten")(nn)
embedding = keras.layers.BatchNormalization(
momentum=0.99, epsilon=0.001, scale=True, name="pre_embedding"
)(nn)
embedding_fp32 = keras.layers.Activation("linear", dtype="float32", name="embedding")(embedding)
model = keras.models.Model(inputs, embedding_fp32, name=xx.name)
model = replace_relu_with_prelu(model=model)
return model
def se_module(inputs, reduction):
"""
Refactored from github.com/HamadYA/GhostFaceNets/blob/main/backbones/ghost_model.py
"""
# get the channel axis
channel_axis = 1 if K.image_data_format() == "channels_first" else -1
# filters = channel axis shape
filters = inputs.shape[channel_axis]
# from None x H x W x C to None x C
se = GlobalAveragePooling2D()(inputs)
# Reshape None x C to None 1 x 1 x C
se = Reshape((1, 1, filters))(se)
# Squeeze by using C*se_ratio. The size will be 1 x 1 x C*se_ratio
se = Conv2D(
reduction,
kernel_size=1,
use_bias=True,
kernel_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(se)
se = Activation("relu")(se)
# Excitation using C filters. The size will be 1 x 1 x C
se = Conv2D(
filters,
kernel_size=1,
use_bias=True,
kernel_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(se)
se = Activation("hard_sigmoid")(se)
return Multiply()([inputs, se])
def ghost_module(inputs, out, convkernel=1, dwkernel=3, add_activation=True):
"""
Refactored from github.com/HamadYA/GhostFaceNets/blob/main/backbones/ghost_model.py
"""
conv_out_channel = out // 2
cc = Conv2D(
conv_out_channel,
convkernel,
use_bias=False,
strides=(1, 1),
padding="same",
kernel_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(inputs)
cc = BatchNormalization(axis=-1)(cc)
if add_activation:
cc = Activation("relu")(cc)
nn = DepthwiseConv2D(
dwkernel,
1,
padding="same",
use_bias=False,
depthwise_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(cc)
nn = BatchNormalization(axis=-1)(nn)
if add_activation:
nn = Activation("relu")(nn)
return Concatenate()([cc, nn])
def ghost_bottleneck(inputs, dwkernel, strides, exp, out, reduction, shortcut=True):
"""
Refactored from github.com/HamadYA/GhostFaceNets/blob/main/backbones/ghost_model.py
"""
nn = ghost_module(inputs, exp, add_activation=True)
if strides > 1:
# Extra depth conv if strides higher than 1
nn = DepthwiseConv2D(
dwkernel,
strides,
padding="same",
use_bias=False,
depthwise_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(nn)
nn = BatchNormalization(axis=-1)(nn)
if reduction > 0:
# Squeeze and excite
nn = se_module(nn, reduction)
# Point-wise linear projection
nn = ghost_module(nn, out, add_activation=False) # ghost2 = GhostModule(exp, out, relu=False)
if shortcut:
xx = DepthwiseConv2D(
dwkernel,
strides,
padding="same",
use_bias=False,
depthwise_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(inputs)
xx = BatchNormalization(axis=-1)(xx)
xx = Conv2D(
out,
(1, 1),
strides=(1, 1),
padding="valid",
use_bias=False,
kernel_initializer=keras.initializers.VarianceScaling(
scale=2.0, mode="fan_out", distribution="truncated_normal"
),
)(xx)
xx = BatchNormalization(axis=-1)(xx)
else:
xx = inputs
return Add()([xx, nn])
def replace_relu_with_prelu(model) -> Model:
"""
Replaces relu activation function in the built model with prelu.
Refactored from github.com/HamadYA/GhostFaceNets/blob/main/backbones/ghost_model.py
Args:
model (Model): built model with relu activation functions
Returns
model (Model): built model with prelu activation functions
"""
def convert_relu(layer):
if isinstance(layer, ReLU) or (
isinstance(layer, Activation) and layer.activation == keras.activations.relu
):
layer_name = layer.name.replace("_relu", "_prelu")
return PReLU(
shared_axes=[1, 2],
alpha_initializer=tf.initializers.Constant(0.25),
name=layer_name,
)
return layer
input_tensors = keras.layers.Input(model.input_shape[1:])
return keras.models.clone_model(model, input_tensors=input_tensors, clone_function=convert_relu)

View File

@ -2,7 +2,17 @@
from typing import Any
# project dependencies
from deepface.basemodels import VGGFace, OpenFace, FbDeepFace, DeepID, ArcFace, SFace, Dlib, Facenet
from deepface.basemodels import (
VGGFace,
OpenFace,
FbDeepFace,
DeepID,
ArcFace,
SFace,
Dlib,
Facenet,
GhostFaceNet
)
from deepface.extendedmodels import Age, Gender, Race, Emotion
@ -31,6 +41,7 @@ def build_model(model_name: str) -> Any:
"Dlib": Dlib.DlibClient,
"ArcFace": ArcFace.ArcFaceClient,
"SFace": SFace.SFaceClient,
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
"Emotion": Emotion.EmotionClient,
"Age": Age.ApparentAgeClient,
"Gender": Gender.GenderClient,

View File

@ -43,7 +43,7 @@ def find(
in the database will be considered in the decision-making process.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2'.
@ -319,7 +319,8 @@ def __find_bulk_embeddings(
Args:
employees (list): list of exact image paths
model_name (str): facial recognition model name
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
target_size (tuple): expected input shape of facial recognition model

View File

@ -28,7 +28,7 @@ def represent(
include information for each detected face.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images.

View File

@ -41,7 +41,7 @@ def analysis(
in the database will be considered in the decision-making process.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
@ -161,7 +161,7 @@ def build_facial_recognition_model(model_name: str) -> tuple:
Build facial recognition model
Args:
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
Returns
input_shape (tuple): input shape of given facial recognitio n model.
"""
@ -184,7 +184,7 @@ def search_identity(
db_path (string): Path to the folder containing image files. All detected faces
in the database will be considered in the decision-making process.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
@ -424,7 +424,7 @@ def perform_facial_recognition(
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
Returns:
img (np.ndarray): image with identified face informations
"""

View File

@ -42,7 +42,7 @@ def verify(
or pre-calculated embeddings.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv)
@ -343,7 +343,8 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
"""
Retrieve pre-tuned threshold values for a model and distance metric pair
Args:
model_name (str): facial recognition model name
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (str): distance metric name. Options are cosine, euclidean
and euclidean_l2.
Returns:
@ -368,6 +369,7 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
"OpenFace": {"cosine": 0.10, "euclidean": 0.55, "euclidean_l2": 0.55},
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
}
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)

View File

@ -6,7 +6,7 @@ from deepface.commons.logger import Logger
logger = Logger("tests/test_facial_recognition_models.py")
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace"]
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"]
metrics = ["cosine", "euclidean", "euclidean_l2"]
detectors = ["opencv", "mtcnn"]