Merge branch 'master' into feat-task-0104-benchmarks

Sefik Ilkin Serengil 2024-04-30 20:22:39 +01:00 committed by GitHub
commit 76b62ba0b1
52 changed files with 953 additions and 446 deletions


@ -0,0 +1,86 @@
name: '🐛 Report a bug'
description: 'Use this template to report DeepFace related issues'
title: '[BUG]: <short description of the issue>'
labels:
- bug
body:
- type: checkboxes
id: preliminary-checks
attributes:
label: Before You Report a Bug, Please Confirm You Have Done The Following...
description: If any of these required steps are not taken, we may not be able to review your issue. Help us to help you!
options:
- label: I have updated to the latest version of the packages.
required: true
- label: I have searched for both [existing issues](https://github.com/serengil/deepface/issues) and [closed issues](https://github.com/serengil/deepface/issues?q=is%3Aissue+is%3Aclosed) and found none that matched my issue.
required: true
- type: input
id: deepface-version
attributes:
label: DeepFace's version
description: |
Please provide your deepface version by running `python -c "import deepface; print(deepface.__version__)"` in your terminal
placeholder: e.g. v0.0.90
validations:
required: true
- type: input
id: python-version
attributes:
label: Python version
description: |
Please provide your Python version by running `python --version` in your terminal
placeholder: e.g. 3.8.5
validations:
required: true
- type: input
id: os
attributes:
label: Operating System
description: |
Please provide your operating system's details
placeholder: e.g. Windows 10 or Ubuntu 20.04
validations:
required: false
- type: textarea
id: dependencies
attributes:
label: Dependencies
description: |
Please provide your Python dependencies by running `pip freeze` in your terminal, in particular the tensorflow and keras versions
validations:
required: true
- type: textarea
id: repro-code
attributes:
label: Reproducible example
description: A ***minimal*** code sample which reproduces the issue
render: Python
validations:
required: true
- type: textarea
id: exception-message
attributes:
label: Relevant Log Output
description: Please share the exception message from your terminal if your program is failing
validations:
required: false
- type: textarea
id: expected
attributes:
label: Expected Result
description: What did you expect to happen?
validations:
required: false
- type: textarea
id: actual
attributes:
label: What happened instead?
description: What actually happened?
validations:
required: false
- type: textarea
id: additional
attributes:
label: Additional Info
description: |
Any additional info you'd like to provide.


@ -0,0 +1,18 @@
name: '✨ Request a New Feature'
description: 'Use this template to propose a new feature'
title: '[FEATURE]: <a short description of my proposal>'
labels:
- 'enhancement'
body:
- type: textarea
id: description
attributes:
label: Description
description: Explain what your proposed feature would do and why this is useful.
validations:
required: true
- type: textarea
id: additional
attributes:
label: Additional Info
description: Any additional info you'd like to provide.


@ -0,0 +1,18 @@
name: '📝 Documentation'
description: 'Use this template to add or improve docs'
title: '[DOC]: <a short description of my proposal>'
labels:
- documentation
body:
- type: textarea
attributes:
label: Suggested Changes
description: What would you like to see happen in the docs?
validations:
required: true
- type: textarea
id: additional
attributes:
label: Additional Info
description: |
Any additional info you'd like to provide.

.github/ISSUE_TEMPLATE/config.yml

@ -0,0 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Ask a question on StackOverflow
about: If you just want to ask a question, consider asking it on StackOverflow!
url: https://stackoverflow.com/search?tab=newest&q=deepface


@ -21,7 +21,7 @@
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/deepface-icon-labeled.png" width="200" height="240"></p>
Deepface is a lightweight [face recognition](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) and facial attribute analysis ([age](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [gender](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [emotion](https://sefiks.com/2018/01/01/facial-expression-recognition-with-keras/) and [race](https://sefiks.com/2019/11/11/race-and-ethnicity-prediction-in-keras/)) framework for python. It is a hybrid face recognition framework wrapping **state-of-the-art** models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/), [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`.
Deepface is a lightweight [face recognition](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) and facial attribute analysis ([age](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [gender](https://sefiks.com/2019/02/13/apparent-age-and-gender-prediction-in-keras/), [emotion](https://sefiks.com/2018/01/01/facial-expression-recognition-with-keras/) and [race](https://sefiks.com/2019/11/11/race-and-ethnicity-prediction-in-keras/)) framework for python. It is a hybrid face recognition framework wrapping **state-of-the-art** models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/), [`FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`.
Experiments show that human beings have 97.53% accuracy on facial recognition tasks, whereas those models have already reached and surpassed that accuracy level.
@ -91,7 +91,7 @@ This function returns an array as embedding. The size of the embedding array wou
```python
embedding = embedding_objs[0]["embedding"]
assert isinstance(embedding, list)
assert model_name = "VGG-Face" and len(embedding) == 4096
assert model_name == "VGG-Face" and len(embedding) == 4096
```
Here, the embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introduction-to-face-recognition-in-deep-learning/) with 4096 slots horizontally. Each slot corresponds to a dimension value in the embedding vector, and the dimension value is explained by the colorbar on the right. Similar to 2D barcodes, the vertical dimension stores no information in the illustration.
@ -100,7 +100,7 @@ Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introdu
**Face recognition models** - [`Demo`](https://youtu.be/i_MOwvhbLdI)
Deepface is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`Google FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`Facebook DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`. The default configuration uses VGG-Face model.
Deepface is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/), [`FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`. The default configuration uses the VGG-Face model.
```python
models = [
@ -195,9 +195,9 @@ Age model got ± 4.65 MAE; gender model got 97.44% accuracy, 96.29% precision an
**Face Detectors** - [`Demo`](https://youtu.be/GZ2p2hj2H5k)
Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that just alignment increases the face recognition accuracy almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`SSD`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MTCNN`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), [`Faster MTCNN`](https://github.com/timesler/facenet-pytorch), [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/), [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/), `Yolo` and `YuNet` detectors are wrapped in deepface.
Face detection and alignment are important early stages of a modern face recognition pipeline. Experiments show that alignment alone increases face recognition accuracy by almost 1%. [`OpenCV`](https://sefiks.com/2020/02/23/face-alignment-for-face-recognition-in-python-within-opencv/), [`Ssd`](https://sefiks.com/2020/08/25/deep-face-detection-with-opencv-in-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), [`MtCnn`](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/), `Faster MtCnn`, [`RetinaFace`](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/), [`MediaPipe`](https://sefiks.com/2022/01/14/deep-face-detection-with-mediapipe/), `Yolo`, `YuNet` and `CenterFace` detectors are wrapped in deepface.
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-portfolio-v5.jpg" width="95%" height="95%"></p>
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-portfolio-v6.jpg" width="95%" height="95%"></p>
All deepface functions accept an optional detector backend input argument. You can switch among those detectors with this argument. OpenCV is the default detector.
@ -207,11 +207,12 @@ backends = [
'ssd',
'dlib',
'mtcnn',
'fastmtcnn',
'retinaface',
'mediapipe',
'yolov8',
'yunet',
'fastmtcnn',
'centerface',
]
#face verification
@ -238,16 +239,15 @@ demographies = DeepFace.analyze(img_path = "img4.jpg",
#face detection and alignment
face_objs = DeepFace.extract_faces(img_path = "img.jpg",
target_size = (224, 224),
detector_backend = backends[4]
)
```
Face recognition models are CNNs and expect standard-sized inputs, so resizing is required before representation. To avoid deformation, deepface adds black padding pixels according to the target size argument after detection and alignment.
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-outputs-20240302.jpg" width="90%" height="90%"></p>
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/detector-outputs-20240414.jpg" width="90%" height="90%"></p>
[RetinaFace](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/) and [MTCNN](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/) seem to overperform in detection and alignment stages but they are much slower. If the speed of your pipeline is more important, then you should use opencv or ssd. On the other hand, if you consider the accuracy, then you should use retinaface or mtcnn.
[RetinaFace](https://sefiks.com/2021/04/27/deep-face-detection-with-retinaface-in-python/) and [MtCnn](https://sefiks.com/2020/09/09/deep-face-detection-with-mtcnn-in-python/) seem to outperform the others in the detection and alignment stages, but they are much slower. If the speed of your pipeline matters more, you should use opencv or ssd; if accuracy matters more, you should use retinaface or mtcnn.
The performance of RetinaFace is very satisfactory even in crowded scenes, as seen in the following illustration. It also comes with excellent facial landmark detection performance: the highlighted red points mark facial landmarks such as the eyes, nose and mouth, which is why the alignment score of RetinaFace is high as well.
@ -317,12 +317,6 @@ $ deepface analyze -img_path tests/dataset/img1.jpg
You can also run these commands if you are running deepface with docker. Please follow the instructions in the [shell script](https://github.com/serengil/deepface/blob/master/scripts/dockerize.sh#L17).
## FAQ and Troubleshooting
If you believe you have identified a bug or encountered a limitation in DeepFace that is not covered in the [existing issues](https://github.com/serengil/deepface/issues) or [closed issues](https://github.com/serengil/deepface/issues?q=is%3Aissue+is%3Aclosed), kindly open a new issue. Ensure that your submission includes clear and detailed reproduction steps, such as your Python version, your DeepFace version (provided by `DeepFace.__version__`), versions of dependent packages (provided by pip freeze), specifics of any exception messages, details about how you are calling DeepFace, and the input image(s) you are using.
Additionally, it is possible to encounter issues due to recently released dependencies, primarily Python itself or TensorFlow. It is recommended to synchronize your dependencies with the versions [specified in my environment](https://github.com/serengil/deepface/blob/master/requirements_local) and [same python version](https://github.com/serengil/deepface/blob/master/Dockerfile#L2) not to have potential compatibility issues.
## Contribution
Pull requests are more than welcome! If you are planning to contribute a large patch, please create an issue first to get any upfront questions or design decisions out of the way.
@ -394,10 +388,7 @@ Also, if you use deepface in your GitHub projects, please add `deepface` in the
DeepFace is licensed under the MIT License - see [`LICENSE`](https://github.com/serengil/deepface/blob/master/LICENSE) for more details.
DeepFace wraps some external face recognition models: [VGG-Face](http://www.robots.ox.ac.uk/~vgg/software/vgg_face/), [Facenet](https://github.com/davidsandberg/facenet/blob/master/LICENSE.md), [OpenFace](https://github.com/iwantooxxoox/Keras-OpenFace/blob/master/LICENSE), [DeepFace](https://github.com/swghosh/DeepFace), [DeepID](https://github.com/Ruoyiran/DeepID/blob/master/LICENSE.md), [ArcFace](https://github.com/leondgarse/Keras_insightface/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt), [SFace](https://github.com/opencv/opencv_zoo/blob/master/models/face_recognition_sface/LICENSE) and [GhostFaceNet](https://github.com/HamadYA/GhostFaceNets/blob/main/LICENSE). Besides, age, gender and race / ethnicity models were trained on the backbone of VGG-Face with transfer learning.
DeepFace wraps some external face recognition models: [VGG-Face](http://www.robots.ox.ac.uk/~vgg/software/vgg_face/), [Facenet](https://github.com/davidsandberg/facenet/blob/master/LICENSE.md), [OpenFace](https://github.com/iwantooxxoox/Keras-OpenFace/blob/master/LICENSE), [DeepFace](https://github.com/swghosh/DeepFace), [DeepID](https://github.com/Ruoyiran/DeepID/blob/master/LICENSE.md), [ArcFace](https://github.com/leondgarse/Keras_insightface/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt), [SFace](https://github.com/opencv/opencv_zoo/blob/master/models/face_recognition_sface/LICENSE) and [GhostFaceNet](https://github.com/HamadYA/GhostFaceNets/blob/main/LICENSE). Besides, age, gender and race / ethnicity models were trained on the backbone of VGG-Face with transfer learning. Similarly, DeepFace wraps many face detectors: [OpenCv](https://github.com/opencv/opencv/blob/4.x/LICENSE), [Ssd](https://github.com/opencv/opencv/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/LICENSE.txt), [MtCnn](https://github.com/ipazc/mtcnn/blob/master/LICENSE), [Fast MtCnn](https://github.com/timesler/facenet-pytorch/blob/master/LICENSE.md), [RetinaFace](https://github.com/serengil/retinaface/blob/master/LICENSE), [MediaPipe](https://github.com/google/mediapipe/blob/master/LICENSE), [YuNet](https://github.com/ShiqiYu/libfacedetection/blob/master/LICENSE), [Yolo](https://github.com/derronqi/yolov8-face/blob/main/LICENSE) and [CenterFace](https://github.com/Star-Clouds/CenterFace/blob/master/LICENSE). License types will be inherited when you intend to utilize those models. Please check the license types of those models for production purposes.
Similarly, DeepFace wraps many face detectors: [OpenCv](https://github.com/opencv/opencv/blob/4.x/LICENSE), [Ssd](https://github.com/opencv/opencv/blob/master/LICENSE), [Dlib](https://github.com/davisking/dlib/blob/master/LICENSE.txt), [MtCnn](https://github.com/ipazc/mtcnn/blob/master/LICENSE), [Fast MtCnn](https://github.com/timesler/facenet-pytorch/blob/master/LICENSE.md), [RetinaFace](https://github.com/serengil/retinaface/blob/master/LICENSE), [MediaPipe](https://github.com/google/mediapipe/blob/master/LICENSE), [YuNet](https://github.com/ShiqiYu/libfacedetection/blob/master/LICENSE) and [Yolo](https://github.com/derronqi/yolov8-face/blob/main/LICENSE).
Licence types will be inherited if you are going to use those models. Please check the license types of those models for production purposes.
DeepFace [logo](https://thenounproject.com/term/face-recognition/2965879/) is created by [Adrien Coquet](https://thenounproject.com/coquet_adrien/) and it is licensed under [Creative Commons: By Attribution 3.0 License](https://creativecommons.org/licenses/by/3.0/).


@ -2,7 +2,7 @@
import os
import warnings
import logging
from typing import Any, Dict, List, Tuple, Union, Optional
from typing import Any, Dict, List, Union, Optional
# this has to be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1"
@ -16,7 +16,7 @@ import tensorflow as tf
# package dependencies
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.commons import logger as log
from deepface.modules import (
modeling,
representation,
@ -25,10 +25,11 @@ from deepface.modules import (
demography,
detection,
streaming,
preprocessing,
)
from deepface import __version__
logger = Logger(module="DeepFace")
logger = log.get_singletonish_logger()
# -----------------------------------
# configurations for dependencies
@ -71,6 +72,7 @@ def verify(
expand_percentage: int = 0,
normalization: str = "base",
silent: bool = False,
threshold: Optional[float] = None,
) -> Dict[str, Any]:
"""
Verify if an image pair represents the same person or different persons.
@ -87,7 +89,8 @@ def verify(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
@ -105,6 +108,11 @@ def verify(
silent (boolean): Suppress or allow some log messages for a quieter analysis process
(default is False).
threshold (float): Specify a threshold to determine whether a pair represents the same
person or different individuals. This threshold is used for comparing distances.
If left unset, default pre-tuned threshold values will be applied based on the specified
model name and distance metric (default is None).
Returns:
result (dict): A dictionary containing verification results with following keys.
@ -141,6 +149,7 @@ def verify(
expand_percentage=expand_percentage,
normalization=normalization,
silent=silent,
threshold=threshold,
)
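The new `threshold` argument flows straight through to `verification.verify`. A minimal usage sketch (the image paths are placeholders; the pre-tuned defaults apply when it is left unset):

```python
from deepface import DeepFace

# default: a pre-tuned threshold is chosen for the model / distance metric pair
result = DeepFace.verify(img1_path="img1.jpg", img2_path="img2.jpg")

# custom threshold: distances above 0.5 are treated as different persons
result = DeepFace.verify(img1_path="img1.jpg", img2_path="img2.jpg", threshold=0.5)
assert isinstance(result["verified"], bool)
```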
@ -167,7 +176,8 @@ def analyze(
Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
@ -271,7 +281,8 @@ def find(
Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
align (boolean): Perform alignment based on the eye positions (default is True).
@ -347,7 +358,8 @@ def represent(
(default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
align (boolean): Perform alignment based on the eye positions (default is True).
@ -405,7 +417,8 @@ def stream(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
@ -439,7 +452,6 @@ def stream(
def extract_faces(
img_path: Union[str, np.ndarray],
target_size: Optional[Tuple[int, int]] = (224, 224),
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
@ -453,11 +465,9 @@ def extract_faces(
img_path (str or np.ndarray): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images.
target_size (tuple): final shape of facial image. black pixels will be
added to resize the image (default is (224, 224)).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True).
@ -485,13 +495,11 @@ def extract_faces(
return detection.extract_faces(
img_path=img_path,
target_size=target_size,
detector_backend=detector_backend,
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
grayscale=grayscale,
human_readable=True,
)
@ -525,7 +533,8 @@ def detectFace(
added to resize the image (default is (224, 224)).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
enforce_detection (boolean): If no face is detected in an image, raise an exception.
Set to False to avoid the exception for low-resolution images (default is True).
@ -538,7 +547,6 @@ def detectFace(
logger.warn("Function detectFace is deprecated. Use extract_faces instead.")
face_objs = extract_faces(
img_path=img_path,
target_size=target_size,
detector_backend=detector_backend,
enforce_detection=enforce_detection,
align=align,
@ -547,4 +555,5 @@ def detectFace(
extracted_face = None
if len(face_objs) > 0:
extracted_face = face_objs[0]["face"]
extracted_face = preprocessing.resize_image(img=extracted_face, target_size=target_size)
return extracted_face
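Since `extract_faces` no longer takes a `target_size` argument after this change, callers that need a fixed model input shape resize afterwards, exactly as the deprecated `detectFace` now does. A minimal sketch, assuming a detectable face in `img.jpg`:

```python
from deepface import DeepFace
from deepface.modules import preprocessing

face_objs = DeepFace.extract_faces(img_path="img.jpg", detector_backend="opencv")
if len(face_objs) > 0:
    face = face_objs[0]["face"]  # RGB, pixel values normalized to [0, 1]
    # pad-and-resize only where a fixed model input shape is required
    face = preprocessing.resize_image(img=face, target_size=(224, 224))
```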


@ -1,8 +1,8 @@
from flask import Blueprint, request
from deepface.api.src.modules.core import service
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="api/src/routes.py")
logger = log.get_singletonish_logger()
blueprint = Blueprint("routes", __name__)


@ -1,10 +1,11 @@
import os
import gdown
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
logger = Logger(module="basemodels.ArcFace")
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
# pylint: disable=unsubscriptable-object


@ -1,10 +1,10 @@
import os
import gdown
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons import logger as log
logger = Logger(module="basemodels.DeepID")
logger = log.get_singletonish_logger()
tf_version = package_utils.get_tf_major_version()


@ -4,10 +4,10 @@ import bz2
import gdown
import numpy as np
from deepface.commons import folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons import logger as log
logger = Logger(module="basemodels.DlibResNet")
logger = log.get_singletonish_logger()
# pylint: disable=too-few-public-methods


@ -1,10 +1,10 @@
import os
import gdown
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons import logger as log
logger = Logger(module="basemodels.Facenet")
logger = log.get_singletonish_logger()
# --------------------------------
# dependency configuration


@ -2,10 +2,10 @@ import os
import zipfile
import gdown
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons import logger as log
logger = Logger(module="basemodels.FbDeepFace")
logger = log.get_singletonish_logger()
# --------------------------------
# dependency configuration


@ -8,9 +8,9 @@ import tensorflow as tf
# project dependencies
from deepface.commons import package_utils, folder_utils
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="basemodels.GhostFaceNet")
logger = log.get_singletonish_logger()
tf_major = package_utils.get_tf_major_version()
if tf_major == 1:


@ -2,10 +2,10 @@ import os
import gdown
import tensorflow as tf
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons import logger as log
logger = Logger(module="basemodels.OpenFace")
logger = log.get_singletonish_logger()
tf_version = package_utils.get_tf_major_version()
if tf_version == 1:


@ -1,15 +1,18 @@
# built-in dependencies
import os
from typing import Any, List
# 3rd party dependencies
import numpy as np
import cv2 as cv
import gdown
# project dependencies
from deepface.commons import folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons import logger as log
logger = Logger(module="basemodels.SFace")
logger = log.get_singletonish_logger()
# pylint: disable=line-too-long, too-few-public-methods


@ -5,9 +5,9 @@ import numpy as np
from deepface.commons import package_utils, folder_utils
from deepface.modules import verification
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="basemodels.VGGFace")
logger = log.get_singletonish_logger()
# ---------------------------------------


@ -1,8 +1,8 @@
import os
from pathlib import Path
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="deepface/commons/folder_utils.py")
logger = log.get_singletonish_logger()
def initialize_folder() -> None:


@ -0,0 +1,149 @@
# built-in dependencies
import os
import io
from typing import List, Union, Tuple
import hashlib
import base64
from pathlib import Path
# 3rd party dependencies
import requests
import numpy as np
import cv2
from PIL import Image
def list_images(path: str) -> List[str]:
"""
List images in a given path
Args:
path (str): directory to search for images
Returns:
images (list): list of exact image paths
"""
images = []
for r, _, f in os.walk(path):
for file in f:
exact_path = os.path.join(r, file)
_, ext = os.path.splitext(exact_path)
ext_lower = ext.lower()
if ext_lower not in {".jpg", ".jpeg", ".png"}:
continue
with Image.open(exact_path) as img: # lazy
if img.format.lower() in ["jpeg", "png"]:
images.append(exact_path)
return images
def find_image_hash(file_path: str) -> str:
"""
Find the hash of the given image file from its file properties
(hashing the image content itself would be a costly operation)
Args:
file_path (str): exact image path
Returns:
hash (str): digest with sha1 algorithm
"""
file_stats = os.stat(file_path)
# some properties
file_size = file_stats.st_size
creation_time = file_stats.st_ctime
modification_time = file_stats.st_mtime
properties = f"{file_size}-{creation_time}-{modification_time}"
hasher = hashlib.sha1()
hasher.update(properties.encode("utf-8"))
return hasher.hexdigest()
def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
"""
Load image from path, url, base64 or numpy array.
Args:
img: a path, url, base64 or numpy array.
Returns:
image (numpy array): the loaded image in BGR format
image name (str): image name itself
"""
# The image is already a numpy array
if isinstance(img, np.ndarray):
return img, "numpy array"
if isinstance(img, Path):
img = str(img)
if not isinstance(img, str):
raise ValueError(f"img must be numpy array or str but it is {type(img)}")
# The image is a base64 string
if img.startswith("data:image/"):
return load_image_from_base64(img), "base64 encoded string"
# The image is a url
if img.lower().startswith("http://") or img.lower().startswith("https://"):
return load_image_from_web(url=img), img
# The image is a path
if os.path.isfile(img) is not True:
raise ValueError(f"Confirm that {img} exists")
# image must be a file on the system then
# the image path must contain only ASCII characters
if img.isascii() is False:
raise ValueError(f"Input image must not have non-english characters - {img}")
img_obj_bgr = cv2.imread(img)
# img_obj_rgb = cv2.cvtColor(img_obj_bgr, cv2.COLOR_BGR2RGB)
return img_obj_bgr, img
def load_image_from_base64(uri: str) -> np.ndarray:
"""
Load image from base64 string.
Args:
uri: a base64 string.
Returns:
numpy array: the loaded image.
"""
encoded_data_parts = uri.split(",")
if len(encoded_data_parts) < 2:
raise ValueError("format error in base64 encoded string")
encoded_data = encoded_data_parts[1]
decoded_bytes = base64.b64decode(encoded_data)
# as in the find functionality, only these extensions are considered
# checking the content type is safer than trusting the file extension
with Image.open(io.BytesIO(decoded_bytes)) as img:
file_type = img.format.lower()
if file_type not in ["jpeg", "png"]:
raise ValueError(f"input image can be jpg or png, but it is {file_type}")
nparr = np.fromstring(decoded_bytes, np.uint8)
img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
return img_bgr
def load_image_from_web(url: str) -> np.ndarray:
"""
Load an image from the web
Args:
url: link for the image
Returns:
img (np.ndarray): equivalent to pre-loaded image from opencv (BGR format)
"""
response = requests.get(url, stream=True, timeout=60)
response.raise_for_status()
image_array = np.asarray(bytearray(response.raw.read()), dtype=np.uint8)
img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
return img
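`load_image` dispatches on the input type; a small sketch of the base64 path, assuming a local `img.jpg` to encode (hypothetical file name):

```python
import base64
from deepface.commons import image_utils

with open("img.jpg", "rb") as f:  # hypothetical local file
    uri = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode("utf-8")

img, name = image_utils.load_image(uri)
print(img.shape, name)  # BGR numpy array, "base64 encoded string"
```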


@ -39,3 +39,16 @@ class Logger:
def dump_log(self, message):
print(f"{str(datetime.now())[2:-7]} - {message}")
def get_singletonish_logger():
# singleton design pattern
global model_obj
if not "model_obj" in globals():
model_obj = {}
if "logger" not in model_obj.keys():
model_obj["logger"] = Logger(module="Singleton")
return model_obj["logger"]


@ -1,14 +1,10 @@
# built-in dependencies
import os
import hashlib
# 3rd party dependencies
import tensorflow as tf
# package dependencies
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="commons.package_utils")
logger = log.get_singletonish_logger()
def get_tf_major_version() -> int:
@ -29,29 +25,6 @@ def get_tf_minor_version() -> int:
return int(tf.__version__.split(".", maxsplit=-1)[1])
def find_hash_of_file(file_path: str) -> str:
"""
Find the hash of given image file with its properties
finding the hash of image content is costly operation
Args:
file_path (str): exact image path
Returns:
hash (str): digest with sha1 algorithm
"""
file_stats = os.stat(file_path)
# some properties
file_size = file_stats.st_size
creation_time = file_stats.st_ctime
modification_time = file_stats.st_mtime
properties = f"{file_size}-{creation_time}-{modification_time}"
hasher = hashlib.sha1()
hasher.update(properties.encode("utf-8"))
return hasher.hexdigest()
def validate_for_keras3():
tf_major = get_tf_major_version()
tf_minor = get_tf_minor_version()


@ -0,0 +1,217 @@
# built-in dependencies
import os
from typing import List
# 3rd party dependencies
import numpy as np
import cv2
import gdown
# project dependencies
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
# pylint: disable=c-extension-no-member
WEIGHTS_URL = "https://github.com/Star-Clouds/CenterFace/raw/master/models/onnx/centerface.onnx"
class CenterFaceClient(Detector):
def __init__(self):
# BUG: model must be flushed for each call
# self.model = self.build_model()
pass
def build_model(self):
"""
Download the pre-trained weights of the CenterFace model if necessary and load the built model
"""
weights_path = f"{folder_utils.get_deepface_home()}/.deepface/weights/centerface.onnx"
if not os.path.isfile(weights_path):
logger.info(f"Downloading CenterFace weights from {WEIGHTS_URL} to {weights_path}...")
try:
gdown.download(WEIGHTS_URL, weights_path, quiet=False)
except Exception as err:
raise ValueError(
f"Exception while downloading CenterFace weights from {WEIGHTS_URL}."
f"You may consider to download it to {weights_path} manually."
) from err
logger.info(f"CenterFace model is just downloaded to {os.path.basename(weights_path)}")
return CenterFace(weight_path=weights_path)
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
"""
Detect and align face with CenterFace
Args:
img (np.ndarray): pre-loaded image as numpy array
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
threshold = float(os.getenv("CENTERFACE_THRESHOLD", "0.80"))
# BUG: model causes problematic results from 2nd call if it is not flushed
# detections, landmarks = self.model.forward(
# img, img.shape[0], img.shape[1], threshold=threshold
# )
detections, landmarks = self.build_model().forward(
img, img.shape[0], img.shape[1], threshold=threshold
)
for i, detection in enumerate(detections):
boxes, confidence = detection[:4], detection[4]
x = boxes[0]
y = boxes[1]
w = boxes[2] - x
h = boxes[3] - y
landmark = landmarks[i]
right_eye = (int(landmark[0]), int(landmark[1]))
left_eye = (int(landmark[2]), int(landmark[3]))
# nose = (int(landmark[4]), int(landmark [5]))
# mouth_right = (int(landmark[6]), int(landmark [7]))
# mouth_left = (int(landmark[8]), int(landmark [9]))
facial_area = FacialAreaRegion(
x=int(x),
y=int(y),
w=int(w),
h=int(h),
left_eye=left_eye,
right_eye=right_eye,
confidence=min(max(0, float(confidence)), 1.0),
)
resp.append(facial_area)
return resp
class CenterFace:
"""
This class is heavily inspired from
github.com/Star-Clouds/CenterFace/blob/master/prj-python/centerface.py
"""
def __init__(self, weight_path: str):
self.net = cv2.dnn.readNetFromONNX(weight_path)
self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = 0, 0, 0, 0
def forward(self, img, height, width, threshold=0.5):
self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = self.transform(height, width)
return self.inference_opencv(img, threshold)
def inference_opencv(self, img, threshold):
blob = cv2.dnn.blobFromImage(
img,
scalefactor=1.0,
size=(self.img_w_new, self.img_h_new),
mean=(0, 0, 0),
swapRB=True,
crop=False,
)
self.net.setInput(blob)
heatmap, scale, offset, lms = self.net.forward(["537", "538", "539", "540"])
return self.postprocess(heatmap, lms, offset, scale, threshold)
def transform(self, h, w):
img_h_new, img_w_new = int(np.ceil(h / 32) * 32), int(np.ceil(w / 32) * 32)
scale_h, scale_w = img_h_new / h, img_w_new / w
return img_h_new, img_w_new, scale_h, scale_w
def postprocess(self, heatmap, lms, offset, scale, threshold):
dets, lms = self.decode(
heatmap, scale, offset, lms, (self.img_h_new, self.img_w_new), threshold=threshold
)
if len(dets) > 0:
dets[:, 0:4:2], dets[:, 1:4:2] = (
dets[:, 0:4:2] / self.scale_w,
dets[:, 1:4:2] / self.scale_h,
)
lms[:, 0:10:2], lms[:, 1:10:2] = (
lms[:, 0:10:2] / self.scale_w,
lms[:, 1:10:2] / self.scale_h,
)
else:
dets = np.empty(shape=[0, 5], dtype=np.float32)
lms = np.empty(shape=[0, 10], dtype=np.float32)
return dets, lms
def decode(self, heatmap, scale, offset, landmark, size, threshold=0.1):
heatmap = np.squeeze(heatmap)
scale0, scale1 = scale[0, 0, :, :], scale[0, 1, :, :]
offset0, offset1 = offset[0, 0, :, :], offset[0, 1, :, :]
c0, c1 = np.where(heatmap > threshold)
boxes, lms = [], []
if len(c0) > 0:
# pylint:disable=consider-using-enumerate
for i in range(len(c0)):
s0, s1 = np.exp(scale0[c0[i], c1[i]]) * 4, np.exp(scale1[c0[i], c1[i]]) * 4
o0, o1 = offset0[c0[i], c1[i]], offset1[c0[i], c1[i]]
s = heatmap[c0[i], c1[i]]
x1, y1 = max(0, (c1[i] + o1 + 0.5) * 4 - s1 / 2), max(
0, (c0[i] + o0 + 0.5) * 4 - s0 / 2
)
x1, y1 = min(x1, size[1]), min(y1, size[0])
boxes.append([x1, y1, min(x1 + s1, size[1]), min(y1 + s0, size[0]), s])
lm = []
for j in range(5):
lm.append(landmark[0, j * 2 + 1, c0[i], c1[i]] * s1 + x1)
lm.append(landmark[0, j * 2, c0[i], c1[i]] * s0 + y1)
lms.append(lm)
boxes = np.asarray(boxes, dtype=np.float32)
keep = self.nms(boxes[:, :4], boxes[:, 4], 0.3)
boxes = boxes[keep, :]
lms = np.asarray(lms, dtype=np.float32)
lms = lms[keep, :]
return boxes, lms
def nms(self, boxes, scores, nms_thresh):
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = np.argsort(scores)[::-1]
num_detections = boxes.shape[0]
suppressed = np.zeros((num_detections,), dtype=bool)
keep = []
for _i in range(num_detections):
i = order[_i]
if suppressed[i]:
continue
keep.append(i)
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, num_detections):
j = order[_j]
if suppressed[j]:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0, xx2 - xx1 + 1)
h = max(0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= nms_thresh:
suppressed[j] = True
return keep
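Once registered under the `centerface` key in DetectorWrapper (below), the detector is reachable from the public API, and the detection threshold read in `detect_faces` above can be overridden via the `CENTERFACE_THRESHOLD` environment variable. A minimal sketch, assuming `img.jpg` contains a face:

```python
import os
from deepface import DeepFace

os.environ["CENTERFACE_THRESHOLD"] = "0.90"  # default is 0.80 (see detect_faces above)
face_objs = DeepFace.extract_faces(img_path="img.jpg", detector_backend="centerface")
for face_obj in face_objs:
    print(face_obj["facial_area"], face_obj["confidence"])
```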


@ -12,10 +12,11 @@ from deepface.detectors import (
Ssd,
Yolo,
YuNet,
CenterFace,
)
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="deepface/detectors/DetectorWrapper.py")
logger = log.get_singletonish_logger()
def build_model(detector_backend: str) -> Any:
@ -38,6 +39,7 @@ def build_model(detector_backend: str) -> Any:
"yolov8": Yolo.YoloClient,
"yunet": YuNet.YuNetClient,
"fastmtcnn": FastMtCnn.FastMtCnnClient,
"centerface": CenterFace.CenterFaceClient,
}
if not "face_detector_obj" in globals():
@ -93,7 +95,7 @@ def detect_faces(
expand_percentage = 0
# find facial areas of given image
facial_areas = face_detector.detect_faces(img=img)
facial_areas = face_detector.detect_faces(img)
results = []
for facial_area in facial_areas:
@ -173,22 +175,30 @@ def rotate_facial_area(
# Angle in radians
angle = angle * np.pi / 180
height, width = size
# Translate the facial area to the center of the image
x = (facial_area[0] + facial_area[2]) / 2 - size[1] / 2
y = (facial_area[1] + facial_area[3]) / 2 - size[0] / 2
x = (facial_area[0] + facial_area[2]) / 2 - width / 2
y = (facial_area[1] + facial_area[3]) / 2 - height / 2
# Rotate the facial area
x_new = x * np.cos(angle) + y * direction * np.sin(angle)
y_new = -x * direction * np.sin(angle) + y * np.cos(angle)
# Translate the facial area back to the original position
x_new = x_new + size[1] / 2
y_new = y_new + size[0] / 2
x_new = x_new + width / 2
y_new = y_new + height / 2
# Calculate the new facial area
# Calculate projected coordinates after alignment
x1 = x_new - (facial_area[2] - facial_area[0]) / 2
y1 = y_new - (facial_area[3] - facial_area[1]) / 2
x2 = x_new + (facial_area[2] - facial_area[0]) / 2
y2 = y_new + (facial_area[3] - facial_area[1]) / 2
return (int(x1), int(y1), int(x2), int(y2))
# validate projected coordinates are in image's boundaries
x1 = max(int(x1), 0)
y1 = max(int(y1), 0)
x2 = min(int(x2), width)
y2 = min(int(y2), height)
return (x1, y1, x2, y2)
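The translate-rotate-translate step above can be checked in isolation. A standalone sketch that mirrors the same formulas (not a call into DetectorWrapper itself):

```python
import numpy as np

def project_center(facial_area, angle, direction, size):
    # mirror of the math above: move the box center into image-centered
    # coordinates, rotate it, then move it back
    height, width = size
    angle = angle * np.pi / 180
    x = (facial_area[0] + facial_area[2]) / 2 - width / 2
    y = (facial_area[1] + facial_area[3]) / 2 - height / 2
    x_new = x * np.cos(angle) + y * direction * np.sin(angle)
    y_new = -x * direction * np.sin(angle) + y * np.cos(angle)
    return x_new + width / 2, y_new + height / 2

# a box left of the center of a 200x200 image lands below the center
# after a 90 degree rotation with direction=1
print(project_center((20, 90, 60, 110), 90, 1, (200, 200)))  # (100.0, 160.0)
```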


@ -5,9 +5,9 @@ import gdown
import numpy as np
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="detectors.DlibWrapper")
logger = log.get_singletonish_logger()
class DlibClient(Detector):


@ -62,13 +62,16 @@ class FastMtCnnClient(Detector):
# this is not a must dependency. do not import it in the global level.
try:
from facenet_pytorch import MTCNN as fast_mtcnn
import torch
except ModuleNotFoundError as e:
raise ImportError(
"FastMtcnn is an optional detector, ensure the library is installed."
"Please install using 'pip install facenet-pytorch' "
) from e
face_detector = fast_mtcnn(device="cpu")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
face_detector = fast_mtcnn(device=device)
return face_detector


@ -7,9 +7,9 @@ import numpy as np
from deepface.detectors import OpenCv
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="detectors.SsdWrapper")
logger = log.get_singletonish_logger()
# pylint: disable=line-too-long, c-extension-no-member


@ -4,9 +4,9 @@ import numpy as np
import gdown
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons import folder_utils
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger()
logger = log.get_singletonish_logger()
# Model's weights paths
PATH = "/.deepface/weights/yolov8n-face.pt"
@ -14,10 +14,6 @@ PATH = "/.deepface/weights/yolov8n-face.pt"
# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
# Confidence thresholds for landmarks detection
# used in alignment_procedure function
LANDMARKS_CONFIDENCE_THRESHOLD = 0.5
class YoloClient(Detector):
def __init__(self):


@ -1,13 +1,18 @@
# built-in dependencies
import os
from typing import Any, List
# 3rd party dependencies
import cv2
import numpy as np
import gdown
# project dependencies
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="detectors.YunetWrapper")
logger = log.get_singletonish_logger()
class YuNetClient(Detector):


@ -3,10 +3,10 @@ import gdown
import numpy as np
from deepface.basemodels import VGGFace
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.Demography import Demography
from deepface.commons import logger as log
logger = Logger(module="extendedmodels.Age")
logger = log.get_singletonish_logger()
# ----------------------------------------
# dependency configurations


@ -1,12 +1,17 @@
# built-in dependencies
import os
# 3rd party dependencies
import gdown
import numpy as np
import cv2
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.Demography import Demography
logger = Logger(module="extendedmodels.Emotion")
# project dependencies
from deepface.commons import package_utils, folder_utils
from deepface.models.Demography import Demography
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
# -------------------------------------------
# pylint: disable=line-too-long


@ -1,12 +1,17 @@
# built-in dependencies
import os
# 3rd party dependencies
import gdown
import numpy as np
# project dependencies
from deepface.basemodels import VGGFace
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.Demography import Demography
from deepface.commons import logger as log
logger = Logger(module="extendedmodels.Gender")
logger = log.get_singletonish_logger()
# -------------------------------------
# pylint: disable=line-too-long


@ -1,12 +1,17 @@
# built-in dependencies
import os
# 3rd party dependencies
import gdown
import numpy as np
# project dependencies
from deepface.basemodels import VGGFace
from deepface.commons import package_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.Demography import Demography
from deepface.commons import logger as log
logger = Logger(module="extendedmodels.Race")
logger = log.get_singletonish_logger()
# --------------------------
# pylint: disable=line-too-long


@ -6,7 +6,7 @@ import numpy as np
from tqdm import tqdm
# project dependencies
from deepface.modules import modeling, detection
from deepface.modules import modeling, detection, preprocessing
from deepface.extendedmodels import Gender, Race, Emotion
@ -34,7 +34,8 @@ def analyze(
Set to False to avoid the exception for low-resolution images (default is True).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
@ -118,7 +119,6 @@ def analyze(
img_objs = detection.extract_faces(
img_path=img_path,
target_size=(224, 224),
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
@ -130,60 +130,68 @@ def analyze(
img_content = img_obj["face"]
img_region = img_obj["facial_area"]
img_confidence = img_obj["confidence"]
if img_content.shape[0] > 0 and img_content.shape[1] > 0:
obj = {}
# facial attribute analysis
pbar = tqdm(
range(0, len(actions)),
desc="Finding actions",
disable=silent if len(actions) > 1 else True,
)
for index in pbar:
action = actions[index]
pbar.set_description(f"Action: {action}")
if img_content.shape[0] == 0 or img_content.shape[1] == 0:
continue
if action == "emotion":
emotion_predictions = modeling.build_model("Emotion").predict(img_content)
sum_of_predictions = emotion_predictions.sum()
# rgb to bgr
img_content = img_content[:, :, ::-1]
obj["emotion"] = {}
for i, emotion_label in enumerate(Emotion.labels):
emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions
obj["emotion"][emotion_label] = emotion_prediction
# resize input image
img_content = preprocessing.resize_image(img=img_content, target_size=(224, 224))
obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)]
obj = {}
# facial attribute analysis
pbar = tqdm(
range(0, len(actions)),
desc="Finding actions",
disable=silent if len(actions) > 1 else True,
)
for index in pbar:
action = actions[index]
pbar.set_description(f"Action: {action}")
elif action == "age":
apparent_age = modeling.build_model("Age").predict(img_content)
# int cast is for exception - object of type 'float32' is not JSON serializable
obj["age"] = int(apparent_age)
if action == "emotion":
emotion_predictions = modeling.build_model("Emotion").predict(img_content)
sum_of_predictions = emotion_predictions.sum()
elif action == "gender":
gender_predictions = modeling.build_model("Gender").predict(img_content)
obj["gender"] = {}
for i, gender_label in enumerate(Gender.labels):
gender_prediction = 100 * gender_predictions[i]
obj["gender"][gender_label] = gender_prediction
obj["emotion"] = {}
for i, emotion_label in enumerate(Emotion.labels):
emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions
obj["emotion"][emotion_label] = emotion_prediction
obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)]
obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)]
elif action == "race":
race_predictions = modeling.build_model("Race").predict(img_content)
sum_of_predictions = race_predictions.sum()
elif action == "age":
apparent_age = modeling.build_model("Age").predict(img_content)
# int cast is for exception - object of type 'float32' is not JSON serializable
obj["age"] = int(apparent_age)
obj["race"] = {}
for i, race_label in enumerate(Race.labels):
race_prediction = 100 * race_predictions[i] / sum_of_predictions
obj["race"][race_label] = race_prediction
elif action == "gender":
gender_predictions = modeling.build_model("Gender").predict(img_content)
obj["gender"] = {}
for i, gender_label in enumerate(Gender.labels):
gender_prediction = 100 * gender_predictions[i]
obj["gender"][gender_label] = gender_prediction
obj["dominant_race"] = Race.labels[np.argmax(race_predictions)]
obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)]
# -----------------------------
# mention facial areas
obj["region"] = img_region
# include image confidence
obj["face_confidence"] = img_confidence
elif action == "race":
race_predictions = modeling.build_model("Race").predict(img_content)
sum_of_predictions = race_predictions.sum()
resp_objects.append(obj)
obj["race"] = {}
for i, race_label in enumerate(Race.labels):
race_prediction = 100 * race_predictions[i] / sum_of_predictions
obj["race"][race_label] = race_prediction
obj["dominant_race"] = Race.labels[np.argmax(race_predictions)]
# -----------------------------
# mention facial areas
obj["region"] = img_region
# include image confidence
obj["face_confidence"] = img_confidence
resp_objects.append(obj)
return resp_objects
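The restructuring above flattens the per-face branch without changing the response schema. A usage sketch matching the README example (keys and actions as shown in this diff):

```python
from deepface import DeepFace

demographies = DeepFace.analyze(img_path="img4.jpg", actions=("age", "gender", "emotion", "race"))
for demography in demographies:
    print(demography["dominant_emotion"], demography["age"], demography["face_confidence"])
```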


@ -1,5 +1,5 @@
# built-in dependencies
from typing import Any, Dict, List, Tuple, Union, Optional
from typing import Any, Dict, List, Tuple, Union
# 3rd part dependencies
import numpy as np
@ -7,33 +7,23 @@ import cv2
from PIL import Image
# project dependencies
from deepface.modules import preprocessing
from deepface.models.Detector import DetectedFace, FacialAreaRegion
from deepface.detectors import DetectorWrapper
from deepface.commons import package_utils
from deepface.commons.logger import Logger
from deepface.commons import image_utils
from deepface.commons import logger as log
logger = Logger(module="deepface/modules/detection.py")
logger = log.get_singletonish_logger()
# pylint: disable=no-else-raise
tf_major_version = package_utils.get_tf_major_version()
if tf_major_version == 1:
from keras.preprocessing import image
elif tf_major_version == 2:
from tensorflow.keras.preprocessing import image
def extract_faces(
img_path: Union[str, np.ndarray],
target_size: Optional[Tuple[int, int]] = (224, 224),
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
expand_percentage: int = 0,
grayscale: bool = False,
human_readable=False,
) -> List[Dict[str, Any]]:
"""
Extract faces from a given image
@ -42,11 +32,9 @@ def extract_faces(
img_path (str or np.ndarray): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images.
target_size (tuple): final shape of facial image. black pixels will be
added to resize the image.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv)
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv)
enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images.
@ -58,13 +46,10 @@ def extract_faces(
grayscale (boolean): Flag to convert the image to grayscale before
processing (default is False).
human_readable (bool): Flag to make the image human readable. 3D RGB for human readable
or 4D BGR for ML models (default is False).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
- "face" (np.ndarray): The detected face as a NumPy array.
- "face" (np.ndarray): The detected face as a NumPy array in RGB format.
- "facial_area" (Dict[str, Any]): The detected face's regions as a dictionary containing:
- keys 'x', 'y', 'w', 'h' with int values
@ -78,7 +63,7 @@ def extract_faces(
resp_objs = []
# img might be path, base64 or numpy array. Convert it to numpy whatever it is.
img, img_name = preprocessing.load_image(img_path)
img, img_name = image_utils.load_image(img_path)
if img is None:
raise ValueError(f"Exception while loading {img_name}")
@ -122,57 +107,11 @@ def extract_faces(
if grayscale is True:
current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
# resize and padding
if target_size is not None:
factor_0 = target_size[0] / current_img.shape[0]
factor_1 = target_size[1] / current_img.shape[1]
factor = min(factor_0, factor_1)
dsize = (
int(current_img.shape[1] * factor),
int(current_img.shape[0] * factor),
)
current_img = cv2.resize(current_img, dsize)
diff_0 = target_size[0] - current_img.shape[0]
diff_1 = target_size[1] - current_img.shape[1]
if grayscale is False:
# Put the base image in the middle of the padded image
current_img = np.pad(
current_img,
(
(diff_0 // 2, diff_0 - diff_0 // 2),
(diff_1 // 2, diff_1 - diff_1 // 2),
(0, 0),
),
"constant",
)
else:
current_img = np.pad(
current_img,
(
(diff_0 // 2, diff_0 - diff_0 // 2),
(diff_1 // 2, diff_1 - diff_1 // 2),
),
"constant",
)
# double check: if target image is not still the same size with target.
if current_img.shape[0:2] != target_size:
current_img = cv2.resize(current_img, target_size)
# normalizing the image pixels
# what this line doing? must?
img_pixels = image.img_to_array(current_img)
img_pixels = np.expand_dims(img_pixels, axis=0)
img_pixels /= 255 # normalize input in [0, 1]
# discard expanded dimension
if human_readable is True and len(img_pixels.shape) == 4:
img_pixels = img_pixels[0]
current_img = current_img / 255 # normalize input in [0, 1]
resp_objs.append(
{
"face": img_pixels[:, :, ::-1] if human_readable is True else img_pixels,
"face": current_img[:, :, ::-1],
"facial_area": {
"x": int(current_region.x),
"y": int(current_region.y),

View File
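With target_size removed, extract_faces now returns each face as an RGB array normalized to [0, 1] and leaves model-specific resizing to the caller. A minimal usage sketch under that assumption (the image path is illustrative):

import cv2
from deepface import DeepFace

face_objs = DeepFace.extract_faces(img_path="dataset/img1.jpg")
face = face_objs[0]["face"]          # RGB float array with values in [0, 1]
face = cv2.resize(face, (224, 224))  # the caller resizes for a 224x224 model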

@ -1,100 +1,19 @@
import os
from typing import Union, Tuple
import base64
from pathlib import Path
# built-in dependencies
from typing import Tuple
# 3rd party
import numpy as np
import cv2
import requests
# project dependencies
from deepface.commons import package_utils
def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
"""
Load image from path, url, base64 or numpy array.
Args:
img: a path, url, base64 or numpy array.
Returns:
image (numpy array): the loaded image in BGR format
image name (str): image name itself
"""
# The image is already a numpy array
if isinstance(img, np.ndarray):
return img, "numpy array"
if isinstance(img, Path):
img = str(img)
if not isinstance(img, str):
raise ValueError(f"img must be numpy array or str but it is {type(img)}")
# The image is a base64 string
if img.startswith("data:image/"):
return load_base64(img), "base64 encoded string"
# The image is a url
if img.lower().startswith("http://") or img.lower().startswith("https://"):
return load_image_from_web(url=img), img
# The image is a path
if os.path.isfile(img) is not True:
raise ValueError(f"Confirm that {img} exists")
# image must be a file on the system then
# image name must contain only english characters
if img.isascii() is False:
raise ValueError(f"Input image must not have non-english characters - {img}")
img_obj_bgr = cv2.imread(img)
# img_obj_rgb = cv2.cvtColor(img_obj_bgr, cv2.COLOR_BGR2RGB)
return img_obj_bgr, img
def load_image_from_web(url: str) -> np.ndarray:
"""
Loading an image from web
Args:
url: link for the image
Returns:
img (np.ndarray): equivalent to pre-loaded image from opencv (BGR format)
"""
response = requests.get(url, stream=True, timeout=60)
response.raise_for_status()
image_array = np.asarray(bytearray(response.raw.read()), dtype=np.uint8)
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
return image
def load_base64(uri: str) -> np.ndarray:
"""Load image from base64 string.
Args:
uri: a base64 string.
Returns:
numpy array: the loaded image.
"""
encoded_data_parts = uri.split(",")
if len(encoded_data_parts) < 2:
raise ValueError("format error in base64 encoded string")
# similar to find functionality, we are just considering these extensions
if not (
uri.startswith("data:image/jpeg")
or uri.startswith("data:image/jpg")
or uri.startswith("data:image/png")
):
raise ValueError(f"input image can be jpg, jpeg or png, but it is {encoded_data_parts}")
encoded_data = encoded_data_parts[1]
nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8)
img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
return img_bgr
tf_major_version = package_utils.get_tf_major_version()
if tf_major_version == 1:
from keras.preprocessing import image
elif tf_major_version == 2:
from tensorflow.keras.preprocessing import image
def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray:
@ -153,3 +72,50 @@ def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray:
raise ValueError(f"unimplemented normalization type - {normalization}")
return img
def resize_image(img: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
"""
Resize an image to the expected input size of a ML model, padding with black pixels.
Args:
img (np.ndarray): pre-loaded image as numpy array
target_size (tuple): input shape of the ML model
Returns:
img (np.ndarray): resized input image
"""
factor_0 = target_size[0] / img.shape[0]
factor_1 = target_size[1] / img.shape[1]
factor = min(factor_0, factor_1)
dsize = (
int(img.shape[1] * factor),
int(img.shape[0] * factor),
)
img = cv2.resize(img, dsize)
diff_0 = target_size[0] - img.shape[0]
diff_1 = target_size[1] - img.shape[1]
# Put the base image in the middle of the padded image
img = np.pad(
img,
(
(diff_0 // 2, diff_0 - diff_0 // 2),
(diff_1 // 2, diff_1 - diff_1 // 2),
(0, 0),
),
"constant",
)
# double check: if the image still does not match the target size, resize it
if img.shape[0:2] != target_size:
img = cv2.resize(img, target_size)
# make it 4-dimensional as ML models expect
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
if img.max() > 1:
img = img.astype(np.float32) / 255.0  # normalize pixel values to [0, 1]
return img
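The new resize_image keeps the aspect ratio, pads the remainder with black pixels, and returns a 4-D batch normalized to [0, 1]. A minimal sketch of that contract (shapes are illustrative):

import numpy as np
from deepface.modules import preprocessing

img = np.random.randint(0, 255, (300, 500, 3), dtype=np.uint8)  # h=300, w=500
out = preprocessing.resize_image(img=img, target_size=(224, 224))
assert out.shape == (1, 224, 224, 3)  # expanded to a 4-D batch
assert out.max() <= 1.0               # pixel values normalized to [0, 1]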

View File

@ -10,12 +10,11 @@ import pandas as pd
from tqdm import tqdm
# project dependencies
from deepface.commons.logger import Logger
from deepface.commons import package_utils
from deepface.modules import representation, detection, modeling, verification
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons import image_utils
from deepface.modules import representation, detection, verification
from deepface.commons import logger as log
logger = Logger(module="deepface/modules/recognition.py")
logger = log.get_singletonish_logger()
def find(
@ -52,7 +51,7 @@ def find(
Default is True. Set to False to avoid the exception for low-resolution images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8'.
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
align (boolean): Perform alignment based on the eye positions.
@ -89,17 +88,25 @@ def find(
tic = time.time()
# -------------------------------
if os.path.isdir(db_path) is not True:
raise ValueError("Passed db_path does not exist!")
model: FacialRecognition = modeling.build_model(model_name)
target_size = model.input_shape
file_parts = [
"ds",
"model",
model_name,
"detector",
detector_backend,
"aligned" if align else "unaligned",
"normalization",
normalization,
"expand",
str(expand_percentage),
]
# ---------------------------------------
file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
file_name = "_".join(file_parts) + ".pkl"
file_name = file_name.replace("-", "").lower()
datastore_path = os.path.join(db_path, file_name)
representations = []
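Because the datastore name now encodes the model, detector, alignment, normalization and expand settings, changing any of them produces a fresh pickle instead of silently reusing a stale one. For the default settings, the generated name works out as follows:

file_parts = ["ds", "model", "VGG-Face", "detector", "opencv",
              "aligned", "normalization", "base", "expand", "0"]
file_name = "_".join(file_parts) + ".pkl"
file_name = file_name.replace("-", "").lower()
# ds_model_vggface_detector_opencv_aligned_normalization_base_expand_0.pkl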
@ -136,7 +143,7 @@ def find(
pickled_images = [representation["identity"] for representation in representations]
# Get the list of images on storage
storage_images = __list_images(path=db_path)
storage_images = image_utils.list_images(path=db_path)
if len(storage_images) == 0:
raise ValueError(f"No item found in {db_path}")
@ -153,7 +160,7 @@ def find(
if identity in old_images:
continue
alpha_hash = current_representation["hash"]
beta_hash = package_utils.find_hash_of_file(identity)
beta_hash = image_utils.find_image_hash(identity)
if alpha_hash != beta_hash:
logger.debug(f"Even though {identity} represented before, it's replaced later.")
replaced_images.append(identity)
@ -179,10 +186,10 @@ def find(
representations += __find_bulk_embeddings(
employees=new_images,
model_name=model_name,
target_size=target_size,
detector_backend=detector_backend,
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
normalization=normalization,
silent=silent,
) # add new images
@ -211,7 +218,6 @@ def find(
# img path might have more than one face
source_objs = detection.extract_faces(
img_path=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
@ -285,27 +291,9 @@ def find(
return resp_obj
def __list_images(path: str) -> List[str]:
"""
List images in a given path
Args:
path (str): path's location
Returns:
images (list): list of exact image paths
"""
images = []
for r, _, f in os.walk(path):
for file in f:
if file.lower().endswith((".jpg", ".jpeg", ".png")):
exact_path = os.path.join(r, file)
images.append(exact_path)
return images
def __find_bulk_embeddings(
employees: List[str],
model_name: str = "VGG-Face",
target_size: tuple = (224, 224),
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
@ -322,8 +310,6 @@ def __find_bulk_embeddings(
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
target_size (tuple): expected input shape of facial recognition model
detector_backend (str): face detector model name
enforce_detection (bool): set this to False if you
@ -348,12 +334,11 @@ def __find_bulk_embeddings(
desc="Finding representations",
disable=silent,
):
file_hash = package_utils.find_hash_of_file(employee)
file_hash = image_utils.find_image_hash(employee)
try:
img_objs = detection.extract_faces(
img_path=employee,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,

View File

@ -3,9 +3,9 @@ from typing import Any, Dict, List, Union
# 3rd party dependencies
import numpy as np
import cv2
# project dependencies
from deepface.commons import image_utils
from deepface.modules import modeling, detection, preprocessing
from deepface.models.FacialRecognition import FacialRecognition
@ -34,7 +34,7 @@ def represent(
Default is True. Set to False to avoid the exception for low-resolution images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8'.
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
align (boolean): Perform alignment based on the eye positions.
@ -67,7 +67,6 @@ def represent(
if detector_backend != "skip":
img_objs = detection.extract_faces(
img_path=img_path,
target_size=(target_size[1], target_size[0]),
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
@ -76,17 +75,11 @@ def represent(
)
else: # skip
# Try to load. If loading fails, an exception will be raised internally
img, _ = preprocessing.load_image(img_path)
# --------------------------------
if len(img.shape) == 4:
img = img[0] # e.g. (1, 224, 224, 3) to (224, 224, 3)
if len(img.shape) == 3:
img = cv2.resize(img, target_size)
img = np.expand_dims(img, axis=0)
# when called from verify, this is already normalized. But needed when user given.
if img.max() > 1:
img = (img.astype(np.float32) / 255.0).astype(np.float32)
# --------------------------------
img, _ = image_utils.load_image(img_path)
if len(img.shape) != 3:
raise ValueError(f"Input img must be 3-dimensional but it has shape {img.shape}")
# make dummy region and confidence to keep compatibility with `extract_faces`
img_objs = [
{
@ -99,8 +92,20 @@ def represent(
for img_obj in img_objs:
img = img_obj["face"]
# rgb to bgr
img = img[:, :, ::-1]
region = img_obj["facial_area"]
confidence = img_obj["confidence"]
# resize to expected shape of ml model
img = preprocessing.resize_image(
img=img,
# dimensions are swapped because DeepID expects a non-square input
target_size=(target_size[1], target_size[0]),
)
# custom normalization
img = preprocessing.normalize_input(img=img, normalization=normalization)
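With resizing moved into represent itself, the 'skip' backend has a stricter contract: the input must be a plain 3-D image, and resize_image takes care of batching. A minimal sketch, assuming the top-level DeepFace.represent forwards to this module (the path is illustrative):

import cv2
from deepface import DeepFace

img = cv2.imread("dataset/img1.jpg")  # 3-D BGR array; a 4-D input now raises ValueError
objs = DeepFace.represent(img_path=img, detector_backend="skip")
print(len(objs[0]["embedding"]))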

View File

@ -10,10 +10,9 @@ import cv2
# project dependencies
from deepface import DeepFace
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="commons.realtime")
logger = log.get_singletonish_logger()
# dependency configuration
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
@ -44,7 +43,8 @@ def analysis(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
@ -62,7 +62,7 @@ def analysis(
"""
# initialize models
build_demography_models(enable_face_analysis=enable_face_analysis)
target_size = build_facial_recognition_model(model_name=model_name)
build_facial_recognition_model(model_name=model_name)
# call a dummy find function for db_path once to create embeddings before starting webcam
_ = search_identity(
detected_face=np.zeros([224, 224, 3]),
@ -89,9 +89,7 @@ def analysis(
faces_coordinates = []
if freeze is False:
faces_coordinates = grab_facial_areas(
img=img, detector_backend=detector_backend, target_size=target_size
)
faces_coordinates = grab_facial_areas(img=img, detector_backend=detector_backend)
# we will pass img to the analysis modules (identity, demography) and draw some illustrations on it;
# that is why we will not be able to extract the detected face from img cleanly afterwards
@ -156,7 +154,7 @@ def analysis(
cv2.destroyAllWindows()
def build_facial_recognition_model(model_name: str) -> tuple:
def build_facial_recognition_model(model_name: str) -> None:
"""
Build facial recognition model
Args:
@ -165,9 +163,8 @@ def build_facial_recognition_model(model_name: str) -> tuple:
Returns
input_shape (tuple): input shape of given facial recognition model.
"""
model: FacialRecognition = DeepFace.build_model(model_name=model_name)
_ = DeepFace.build_model(model_name=model_name)
logger.info(f"{model_name} is built")
return model.input_shape
def search_identity(
@ -186,7 +183,8 @@ def search_identity(
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
Returns:
@ -231,7 +229,6 @@ def search_identity(
# load found identity image - extracted if possible
target_objs = DeepFace.extract_faces(
img_path=target_path,
target_size=(IDENTIFIED_IMG_SIZE, IDENTIFIED_IMG_SIZE),
detector_backend=detector_backend,
enforce_detection=False,
align=True,
@ -243,6 +240,7 @@ def search_identity(
# extract 1st item directly
target_obj = target_objs[0]
target_img = target_obj["face"]
target_img = cv2.resize(target_img, (IDENTIFIED_IMG_SIZE, IDENTIFIED_IMG_SIZE))
target_img *= 255
target_img = target_img[:, :, ::-1]
else:
@ -346,15 +344,15 @@ def countdown_to_release(
def grab_facial_areas(
img: np.ndarray, detector_backend: str, target_size: Tuple[int, int], threshold: int = 130
img: np.ndarray, detector_backend: str, threshold: int = 130
) -> List[Tuple[int, int, int, int]]:
"""
Find facial area coordinates in the given image
Args:
img (np.ndarray): image itself
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
target_size (tuple): input shape of the facial recognition model.
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
threshold (int): threshold for facial area size; smaller ones are discarded
Returns
result (list): list of tuple with x, y, w and h coordinates
@ -363,7 +361,6 @@ def grab_facial_areas(
face_objs = DeepFace.extract_faces(
img_path=img,
detector_backend=detector_backend,
target_size=target_size,
# you may consider extracting with a larger expand value
expand_percentage=0,
)
@ -420,7 +417,8 @@ def perform_facial_recognition(
db_path (string): Path to the folder containing image files. All detected faces
in the database will be considered in the decision-making process.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,

View File

@ -1,6 +1,6 @@
# built-in dependencies
import time
from typing import Any, Dict, Union, List, Tuple
from typing import Any, Dict, Optional, Union, List, Tuple
# 3rd party dependencies
import numpy as np
@ -8,9 +8,9 @@ import numpy as np
# project dependencies
from deepface.modules import representation, detection, modeling
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger(module="deepface/modules/verification.py")
logger = log.get_singletonish_logger()
def verify(
@ -24,6 +24,7 @@ def verify(
expand_percentage: int = 0,
normalization: str = "base",
silent: bool = False,
threshold: Optional[float] = None,
) -> Dict[str, Any]:
"""
Verify if an image pair represents the same person or different persons.
@ -45,7 +46,8 @@ def verify(
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv)
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'
(default is opencv)
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2' (default is cosine).
@ -62,6 +64,11 @@ def verify(
silent (boolean): Suppress or allow some log messages for a quieter analysis process
(default is False).
threshold (float): Specify a threshold to determine whether a pair represents the same
person or different individuals. This threshold is used for comparing distances.
If left unset, default pre-tuned threshold values will be applied based on the specified
model name and distance metric (default is None).
Returns:
result (dict): A dictionary containing verification results.
@ -185,7 +192,7 @@ def verify(
)
# find the face pair with minimum distance
threshold = find_threshold(model_name, distance_metric)
threshold = threshold or find_threshold(model_name, distance_metric)
distance = float(min(distances)) # best distance
facial_areas = facial_areas[np.argmin(distances)]
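The optional threshold lets callers trade false accepts against false rejects without touching the pre-tuned defaults; when it is omitted, find_threshold is consulted as before. A minimal sketch, assuming the top-level DeepFace.verify exposes the new argument (paths are illustrative):

from deepface import DeepFace

result = DeepFace.verify(
    img1_path="dataset/img1.jpg",
    img2_path="dataset/img3.jpg",
    threshold=0.5,  # overrides the pre-tuned model/metric threshold
)
print(result["verified"], result["distance"])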
@ -223,12 +230,8 @@ def __extract_faces_and_embeddings(
embeddings = []
facial_areas = []
model: FacialRecognition = modeling.build_model(model_name)
target_size = model.input_shape
img_objs = detection.extract_faces(
img_path=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,

Binary image file added (490 KiB); content not shown.

Binary image file added (228 KiB); content not shown.

View File

@ -1,3 +1,4 @@
requests>=2.27.1
numpy>=1.14.0
pandas>=0.23.4
gdown>=3.10.1

View File

@ -1,11 +1,15 @@
# 3rd party dependencies
import matplotlib.pyplot as plt
import numpy as np
import cv2
# project dependencies
from deepface import DeepFace
from deepface.modules import verification
from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
from deepface.commons import logger as log
logger = Logger()
logger = log.get_singletonish_logger()
# ----------------------------------------------
# build face recognition model
@ -21,11 +25,13 @@ logger.info(f"target_size: {target_size}")
# ----------------------------------------------
# load images and find embeddings
img1 = DeepFace.extract_faces(img_path="dataset/img1.jpg", target_size=target_size)[0]["face"]
img1 = DeepFace.extract_faces(img_path="dataset/img1.jpg")[0]["face"]
img1 = cv2.resize(img1, target_size)
img1 = np.expand_dims(img1, axis=0) # to (1, 224, 224, 3)
img1_representation = model.forward(img1)
img2 = DeepFace.extract_faces(img_path="dataset/img3.jpg", target_size=target_size)[0]["face"]
img2 = DeepFace.extract_faces(img_path="dataset/img3.jpg")[0]["face"]
img2 = cv2.resize(img2, target_size)
img2 = np.expand_dims(img2, axis=0)
img2_representation = model.forward(img2)

View File

@ -1,5 +1,8 @@
# 3rd party dependencies
import cv2
import matplotlib.pyplot as plt
# project dependencies
from deepface.modules import streaming
from deepface import DeepFace
@ -7,9 +10,11 @@ img_path = "dataset/img1.jpg"
img = cv2.imread(img_path)
overlay_img_path = "dataset/img6.jpg"
face_objs = DeepFace.extract_faces(overlay_img_path, target_size=(112, 112))
face_objs = DeepFace.extract_faces(overlay_img_path)
overlay_img = face_objs[0]["face"][:, :, ::-1] * 255
overlay_img = cv2.resize(overlay_img, (112, 112))
raw_img = img.copy()
demographies = DeepFace.analyze(img_path=img_path, actions=("age", "gender", "emotion"))

View File

@ -1,8 +1,12 @@
# 3rd party dependencies
import cv2
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger("tests/test_analyze.py")
# project dependencies
from deepface import DeepFace
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
detectors = ["opencv", "mtcnn"]

View File

@ -1,10 +1,12 @@
import unittest
from deepface.commons.logger import Logger
from deepface.api.src.app import create_app
# built-in dependencies
import base64
import unittest
# project dependencies
from deepface.api.src.app import create_app
from deepface.commons import logger as log
logger = Logger("tests/test_api.py")
logger = log.get_singletonish_logger()
class TestVerifyEndpoint(unittest.TestCase):

View File

@ -1,9 +1,12 @@
# 3rd party dependencies
import pytest
import numpy as np
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger("tests/test_enforce_detection.py")
# project dependencies
from deepface import DeepFace
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
def test_enabled_enforce_detection_for_non_facial_input():

View File

@ -1,9 +1,16 @@
# built-in dependencies
import base64
# 3rd party dependencies
import numpy as np
import pytest
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger("tests/test_extract_faces.py")
# project dependencies
from deepface import DeepFace
from deepface.commons import image_utils
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
detectors = ["opencv", "mtcnn"]
@ -48,3 +55,24 @@ def test_backends_for_not_enforced_detection_with_non_facial_inputs():
)
assert objs[0]["face"].shape == (224, 224, 3)
logger.info("✅ extract_faces for not enforced detection and non-facial image test is done")
def test_file_types_while_loading_base64():
img1_path = "dataset/img47.jpg"
img1_base64 = image_to_base64(image_path=img1_path)
with pytest.raises(ValueError, match="input image can be jpg or png, but it is"):
_ = image_utils.load_image_from_base64(uri=img1_base64)
img2_path = "dataset/img1.jpg"
img2_base64 = image_to_base64(image_path=img2_path)
img2 = image_utils.load_image_from_base64(uri=img2_base64)
# 3 dimensional image should be loaded
assert len(img2.shape) == 3
def image_to_base64(image_path):
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return "data:image/jpeg," + encoded_string

View File

@ -1,17 +1,24 @@
# built-in dependencies
import os
# 3rd party dependencies
import cv2
import pandas as pd
# project dependencies
from deepface import DeepFace
from deepface.modules import verification
from deepface.commons.logger import Logger
from deepface.commons import image_utils
from deepface.commons import logger as log
logger = Logger("tests/test_find.py")
logger = log.get_singletonish_logger()
threshold = verification.find_threshold(model_name="VGG-Face", distance_metric="cosine")
def test_find_with_exact_path():
img_path = os.path.join("dataset","img1.jpg")
img_path = os.path.join("dataset", "img1.jpg")
dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True)
assert len(dfs) > 0
for df in dfs:
@ -31,7 +38,7 @@ def test_find_with_exact_path():
def test_find_with_array_input():
img_path = os.path.join("dataset","img1.jpg")
img_path = os.path.join("dataset", "img1.jpg")
img1 = cv2.imread(img_path)
dfs = DeepFace.find(img1, db_path="dataset", silent=True)
assert len(dfs) > 0
@ -53,7 +60,7 @@ def test_find_with_array_input():
def test_find_with_extracted_faces():
img_path = os.path.join("dataset","img1.jpg")
img_path = os.path.join("dataset", "img1.jpg")
face_objs = DeepFace.extract_faces(img_path)
img = face_objs[0]["face"]
dfs = DeepFace.find(img, db_path="dataset", detector_backend="skip", silent=True)
@ -72,3 +79,25 @@ def test_find_with_extracted_faces():
logger.debug(df.head())
assert df.shape[0] > 0
logger.info("✅ test find for extracted face input done")
def test_filetype_for_find():
"""
only jpg and png images can be loaded into the database
"""
img_path = os.path.join("dataset", "img1.jpg")
dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True)
df = dfs[0]
# img47 is webp even though its extension is jpg
assert df[df["identity"] == "dataset/img47.jpg"].shape[0] == 0
def test_filetype_for_find_bulk_embeddings():
imgs = image_utils.list_images("dataset")
assert len(imgs) > 0
# img47 is webp even though its extension is jpg
assert "dataset/img47.jpg" not in imgs
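The implementation of image_utils.list_images is not shown in this diff; these tests only pin its behavior, namely that files are filtered by actual content rather than extension. A hypothetical content-based filter illustrating the idea, assuming Pillow is available:

from pathlib import Path
from PIL import Image

def list_true_images(path: str) -> list:
    # collect files that really are jpg/png, ignoring lying extensions
    images = []
    for file in sorted(Path(path).rglob("*")):
        if file.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
            continue
        try:
            with Image.open(file) as im:
                if im.format in {"JPEG", "PNG"}:  # inspects magic bytes, not the name
                    images.append(str(file))
        except OSError:
            continue
    return images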

View File

@ -1,9 +1,11 @@
# built-in dependencies
import cv2
# project dependencies
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger("tests/test_represent.py")
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
def test_standard_represent():
img_path = "dataset/img1.jpg"

View File

@ -1,10 +1,12 @@
# 3rd party dependencies
import pytest
import cv2
# project dependencies
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger("tests/test_facial_recognition_models.py")
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
models = ["VGG-Face", "Facenet", "Facenet512", "ArcFace", "GhostFaceNet"]
metrics = ["cosine", "euclidean", "euclidean_l2"]

View File

@ -1,8 +1,11 @@
# built-in dependencies
import json
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger("tests/test_version.py")
# project dependencies
from deepface import DeepFace
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
def test_version():

View File

@ -1,8 +1,11 @@
# 3rd party dependencies
import matplotlib.pyplot as plt
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger()
# project dependencies
from deepface import DeepFace
from deepface.commons import logger as log
logger = log.get_singletonish_logger()
# some models (e.g. Dlib) and detectors (e.g. retinaface) do not have test cases
# because they require installing huge packages
@ -18,6 +21,7 @@ model_names = [
"Dlib",
"ArcFace",
"SFace",
"GhostFaceNet",
]
detector_backends = [
@ -30,9 +34,9 @@ detector_backends = [
"retinaface",
"yunet",
"yolov8",
"centerface",
]
# verification
for model_name in model_names:
obj = DeepFace.verify(
@ -56,7 +60,6 @@ dfs = DeepFace.find(
for df in dfs:
logger.info(df)
expand_areas = [0]
img_paths = ["dataset/img11.jpg", "dataset/img11_reflection.jpg"]
for expand_area in expand_areas:
@ -71,7 +74,7 @@ for expand_area in expand_areas:
)
for face_obj in face_objs:
face = face_obj["face"]
logger.info(detector_backend)
logger.info(f"testing {img_path} with {detector_backend}")
logger.info(face_obj["facial_area"])
logger.info(face_obj["confidence"])
@ -95,7 +98,10 @@ for expand_area in expand_areas:
le_x = face_obj["facial_area"]["left_eye"][0]
assert re_x < le_x, "right eye must be the right eye of the person"
assert isinstance(face_obj["confidence"], float)
type_conf = type(face_obj["confidence"])
assert isinstance(
face_obj["confidence"], float
), f"confidence type must be float but it is {type_conf}"
assert face_obj["confidence"] <= 1
plt.imshow(face)