Mirror of https://github.com/serengil/deepface.git, synced 2025-06-09 12:57:08 +00:00.

Merge pull request #1072 from serengil/feat-task-0702-find-enhancements ("Feat task 0702 find enhancements").

This commit is contained in: commit 644fc67e9e

In summary, the change set reworks the `find` datastore from tuples into a list of dicts keyed by identity, hash, and embedding; uses per-file SHA-1 hashes to detect replaced images; consolidates the distance-metric branching into a shared `find_distance` wrapper; and lets `verify` accept pre-calculated embeddings.
README.md:

@@ -312,9 +312,11 @@ $ deepface analyze -img_path tests/dataset/img1.jpg
 
 You can also run these commands if you are running deepface with docker. Please follow the instructions in the [shell script](https://github.com/serengil/deepface/blob/master/scripts/dockerize.sh#L17).
 
-## Contribution [](https://github.com/serengil/deepface/actions/workflows/tests.yml)
+## Contribution
 
-Pull requests are more than welcome! You should run the unit tests and linting locally by running `make test && make lint` before creating a PR. Once a PR is sent, the GitHub test workflow will run automatically and unit test results will be available in [GitHub actions](https://github.com/serengil/deepface/actions) before approval. Besides, the workflow will evaluate the code with pylint as well.
+Pull requests are more than welcome! If you are planning to contribute a large patch, please create an issue first to get any upfront questions or design decisions out of the way.
+
+Before creating a PR, you should run the unit tests and linting locally with the `make test && make lint` command. Once a PR is sent, the GitHub test workflow will run automatically, and the unit test and linting jobs will be available in [GitHub actions](https://github.com/serengil/deepface/actions) before approval.
 
 ## Support
deepface/DeepFace.py:

@@ -62,6 +62,7 @@ def verify(
     align: bool = True,
     expand_percentage: int = 0,
     normalization: str = "base",
+    silent: bool = False,
 ) -> Dict[str, Any]:
     """
     Verify if an image pair represents the same person or different persons.
@@ -91,6 +92,9 @@ def verify(
         normalization (string): Normalize the input image before feeding it to the model.
             Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base)
 
+        silent (boolean): Suppress or allow some log messages for a quieter analysis process
+            (default is False).
+
     Returns:
         result (dict): A dictionary containing verification results with following keys.
 
@@ -126,6 +130,7 @@ def verify(
         align=align,
         expand_percentage=expand_percentage,
         normalization=normalization,
+        silent=silent,
     )
 
deepface/commons/package_utils.py:

@@ -1,3 +1,6 @@
+# built-in dependencies
+import hashlib
+
 # 3rd party dependencies
 import tensorflow as tf
 
@@ -14,3 +17,16 @@ def get_tf_major_version() -> int:
         major_version (int)
     """
     return int(tf.__version__.split(".", maxsplit=1)[0])
+
+
+def find_hash_of_file(file_path: str) -> str:
+    """
+    Find hash of image file
+    Args:
+        file_path (str): exact image path
+    Returns:
+        hash (str): digest with sha1 algorithm
+    """
+    with open(file_path, "rb") as f:
+        digest = hashlib.sha1(f.read()).hexdigest()
+    return digest
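The new `find_hash_of_file` helper fingerprints every database image so that `find` can notice when a file is overwritten in place. A minimal sketch of that change-detection idea (file names are illustrative):

```python
import hashlib

def find_hash_of_file(file_path: str) -> str:
    # digest of the raw file bytes; any change to the image yields a new hash
    with open(file_path, "rb") as f:
        return hashlib.sha1(f.read()).hexdigest()

stored_hash = find_hash_of_file("db/alice.jpg")  # recorded when the image was indexed
# ... later, db/alice.jpg is overwritten with a different photo ...
if find_hash_of_file("db/alice.jpg") != stored_hash:
    print("db/alice.jpg changed on disk; its embedding must be recomputed")
```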
@@ -34,7 +34,7 @@ def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
         return load_base64(img), "base64 encoded string"
 
     # The image is a url
-    if img.startswith("http://") or img.startswith("https://"):
+    if img.lower().startswith("http://") or img.lower().startswith("https://"):
         return load_image_from_web(url=img), img
 
     # The image is a path
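Lower-casing before the scheme check makes the URL detection case-insensitive, so an input such as `HTTP://...` is still routed to the web loader instead of being treated as a local file path. A quick illustration:

```python
for candidate in ["https://example.com/img.jpg", "HTTPS://example.com/img.jpg"]:
    # without .lower(), the second candidate would fall through to the
    # file-path branch and fail as a non-existent local file
    print(candidate.lower().startswith(("http://", "https://")))  # True for both
```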
deepface/modules/recognition.py:

@@ -1,7 +1,7 @@
 # built-in dependencies
 import os
 import pickle
-from typing import List, Union, Optional
+from typing import List, Union, Optional, Dict, Any
 import time
 
 # 3rd party dependencies
@@ -11,6 +11,7 @@ from tqdm import tqdm
 
 # project dependencies
 from deepface.commons.logger import Logger
+from deepface.commons import package_utils
 from deepface.modules import representation, detection, modeling, verification
 from deepface.models.FacialRecognition import FacialRecognition
 
@@ -97,14 +98,16 @@ def find(
 
     # ---------------------------------------
 
-    file_name = f"representations_{model_name}.pkl"
-    file_name = file_name.replace("-", "_").lower()
+    file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
+    file_name = file_name.replace("-", "").lower()
     datastore_path = os.path.join(db_path, file_name)
     representations = []
 
+    # required columns for representations
     df_cols = [
         "identity",
-        f"{model_name}_representation",
+        "hash",
+        "embedding",
         "target_x",
         "target_y",
         "target_w",
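Encoding both the model and the detector backend in the pickle name means switching either one starts a fresh datastore instead of silently reusing incompatible embeddings. For example, the resulting file name can be computed directly (values illustrative):

```python
model_name, detector_backend = "VGG-Face", "opencv"
file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
file_name = file_name.replace("-", "").lower()
print(file_name)  # ds_vggface_opencv_v2.pkl
```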
@@ -120,31 +123,55 @@ def find(
     with open(datastore_path, "rb") as f:
         representations = pickle.load(f)
 
-    # Check if the representations are out-of-date
-    if len(representations) > 0:
-        if len(representations[0]) != len(df_cols):
-            raise ValueError(
-                f"Seems existing {datastore_path} is out-of-the-date."
-                "Please delete it and re-run."
-            )
-        pickled_images = [representation[0] for representation in representations]
-    else:
-        pickled_images = []
+    # check each item of representations list has required keys
+    for i, current_representation in enumerate(representations):
+        missing_keys = list(set(df_cols) - set(current_representation.keys()))
+        if len(missing_keys) > 0:
+            raise ValueError(
+                f"{i}-th item does not have some required keys - {missing_keys}."
+                f"Consider to delete {datastore_path}"
+            )
+
+    # embedded images
+    pickled_images = [representation["identity"] for representation in representations]
 
     # Get the list of images on storage
     storage_images = __list_images(path=db_path)
 
+    if len(storage_images) == 0:
+        raise ValueError(f"No item found in {db_path}")
+
     # Enforce data consistency amongst on disk images and pickle file
     must_save_pickle = False
     new_images = list(set(storage_images) - set(pickled_images))  # images added to storage
     old_images = list(set(pickled_images) - set(storage_images))  # images removed from storage
 
-    if not silent and (len(new_images) > 0 or len(old_images) > 0):
-        logger.info(f"Found {len(new_images)} new images and {len(old_images)} removed images")
+    # detect replaced images
+    replaced_images = []
+    for current_representation in representations:
+        identity = current_representation["identity"]
+        if identity in old_images:
+            continue
+        alpha_hash = current_representation["hash"]
+        beta_hash = package_utils.find_hash_of_file(identity)
+        if alpha_hash != beta_hash:
+            logger.debug(f"Even though {identity} represented before, it's replaced later.")
+            replaced_images.append(identity)
+
+    if not silent and (len(new_images) > 0 or len(old_images) > 0 or len(replaced_images) > 0):
+        logger.info(
+            f"Found {len(new_images)} newly added image(s)"
+            f", {len(old_images)} removed image(s)"
+            f", {len(replaced_images)} replaced image(s)."
+        )
+
+    # append replaced images into both old and new images. these will be dropped and re-added.
+    new_images = new_images + replaced_images
+    old_images = old_images + replaced_images
 
     # remove old images first
     if len(old_images) > 0:
-        representations = [rep for rep in representations if rep[0] not in old_images]
+        representations = [rep for rep in representations if rep["identity"] not in old_images]
         must_save_pickle = True
 
     # find representations for new images
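Set arithmetic over identities, combined with the stored hashes, splits the database into added, removed, and replaced images; replaced ones are appended to both lists so they get dropped and re-embedded. A self-contained sketch of that reconciliation (sample data made up):

```python
pickled = {"a.jpg": "h1", "b.jpg": "h2", "c.jpg": "h3"}  # identity -> hash at indexing time
on_disk = {"a.jpg": "h1", "b.jpg": "hX", "d.jpg": "h4"}  # identity -> current hash

new_images = sorted(set(on_disk) - set(pickled))  # ['d.jpg'] - added to storage
old_images = sorted(set(pickled) - set(on_disk))  # ['c.jpg'] - removed from storage
replaced_images = sorted(
    name for name in set(pickled) & set(on_disk) if pickled[name] != on_disk[name]
)  # ['b.jpg'] - same identity, different content

# replaced images are treated as removed + re-added
new_images += replaced_images
old_images += replaced_images
print(new_images, old_images)  # ['d.jpg', 'b.jpg'] ['c.jpg', 'b.jpg']
```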
@@ -176,10 +203,10 @@ def find(
 
     # ----------------------------
     # now, we got representations for facial database
-    df = pd.DataFrame(
-        representations,
-        columns=df_cols,
-    )
+    df = pd.DataFrame(representations)
+
+    if silent is False:
+        logger.info(f"Searching {img_path} in {df.shape[0]} length datastore")
 
     # img path might have more than once face
     source_objs = detection.extract_faces(
@@ -216,7 +243,7 @@ def find(
 
         distances = []
         for _, instance in df.iterrows():
-            source_representation = instance[f"{model_name}_representation"]
+            source_representation = instance["embedding"]
             if source_representation is None:
                 distances.append(float("inf"))  # no representation for this image
                 continue
@@ -230,21 +257,9 @@ def find(
                 + " after pickle created. Delete the {file_name} and re-run."
             )
 
-            if distance_metric == "cosine":
-                distance = verification.find_cosine_distance(
-                    source_representation, target_representation
-                )
-            elif distance_metric == "euclidean":
-                distance = verification.find_euclidean_distance(
-                    source_representation, target_representation
-                )
-            elif distance_metric == "euclidean_l2":
-                distance = verification.find_euclidean_distance(
-                    verification.l2_normalize(source_representation),
-                    verification.l2_normalize(target_representation),
-                )
-            else:
-                raise ValueError(f"invalid distance metric passes - {distance_metric}")
+            distance = verification.find_distance(
+                source_representation, target_representation, distance_metric
+            )
 
             distances.append(distance)
 
@@ -254,7 +269,7 @@ def find(
         result_df["threshold"] = target_threshold
         result_df["distance"] = distances
 
-        result_df = result_df.drop(columns=[f"{model_name}_representation"])
+        result_df = result_df.drop(columns=["embedding"])
         # pylint: disable=unsubscriptable-object
         result_df = result_df[result_df["distance"] <= target_threshold]
         result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True)
@@ -297,7 +312,7 @@ def __find_bulk_embeddings(
     expand_percentage: int = 0,
     normalization: str = "base",
     silent: bool = False,
-):
+) -> List[Dict["str", Any]]:
     """
     Find embeddings of a list of images
 
@@ -323,8 +338,8 @@ def __find_bulk_embeddings(
 
         silent (bool): enable or disable informative logging
     Returns:
-        representations (list): pivot list of embeddings with
-            image name and detected face area's coordinates
+        representations (list): pivot list of dict with
+            image name, hash, embedding and detected face area's coordinates
     """
     representations = []
     for employee in tqdm(
@@ -332,6 +347,8 @@ def __find_bulk_embeddings(
         desc="Finding representations",
         disable=silent,
     ):
+        file_hash = package_utils.find_hash_of_file(employee)
+
         try:
             img_objs = detection.extract_faces(
                 img_path=employee,
@@ -342,15 +359,23 @@ def __find_bulk_embeddings(
                 align=align,
                 expand_percentage=expand_percentage,
             )
 
         except ValueError as err:
-            logger.error(
-                f"Exception while extracting faces from {employee}: {str(err)}"
-            )
+            logger.error(f"Exception while extracting faces from {employee}: {str(err)}")
             img_objs = []
 
         if len(img_objs) == 0:
-            logger.warn(f"No face detected in {employee}. It will be skipped in detection.")
-            representations.append((employee, None, 0, 0, 0, 0))
+            representations.append(
+                {
+                    "identity": employee,
+                    "hash": file_hash,
+                    "embedding": None,
+                    "target_x": 0,
+                    "target_y": 0,
+                    "target_w": 0,
+                    "target_h": 0,
+                }
+            )
         else:
             for img_obj in img_objs:
                 img_content = img_obj["face"]
@@ -365,13 +390,16 @@ def __find_bulk_embeddings(
                 )
 
                 img_representation = embedding_obj[0]["embedding"]
-                representations.append((
-                    employee,
-                    img_representation,
-                    img_region["x"],
-                    img_region["y"],
-                    img_region["w"],
-                    img_region["h"]
-                ))
+                representations.append(
+                    {
+                        "identity": employee,
+                        "hash": file_hash,
+                        "embedding": img_representation,
+                        "target_x": img_region["x"],
+                        "target_y": img_region["y"],
+                        "target_w": img_region["w"],
+                        "target_h": img_region["h"],
+                    }
+                )
 
     return representations
deepface/modules/verification.py:

@@ -1,6 +1,6 @@
 # built-in dependencies
 import time
-from typing import Any, Dict, Union
+from typing import Any, Dict, Union, List, Tuple
 
 # 3rd party dependencies
 import numpy as np
@@ -8,11 +8,14 @@ import numpy as np
 # project dependencies
 from deepface.modules import representation, detection, modeling
 from deepface.models.FacialRecognition import FacialRecognition
+from deepface.commons.logger import Logger
+
+logger = Logger(module="deepface/modules/verification.py")
 
 
 def verify(
-    img1_path: Union[str, np.ndarray],
-    img2_path: Union[str, np.ndarray],
+    img1_path: Union[str, np.ndarray, List[float]],
+    img2_path: Union[str, np.ndarray, List[float]],
     model_name: str = "VGG-Face",
     detector_backend: str = "opencv",
     distance_metric: str = "cosine",
@@ -20,6 +23,7 @@ def verify(
     align: bool = True,
     expand_percentage: int = 0,
     normalization: str = "base",
+    silent: bool = False,
 ) -> Dict[str, Any]:
     """
     Verify if an image pair represents the same person or different persons.
@@ -30,10 +34,10 @@ def verify(
 
     Args:
         img1_path (str or np.ndarray): Path to the first image. Accepts exact image path
-            as a string, numpy array (BGR), or base64 encoded images.
+            as a string, numpy array (BGR), base64 encoded images or pre-calculated embeddings.
 
         img2_path (str or np.ndarray): Path to the second image. Accepts exact image path
-            as a string, numpy array (BGR), or base64 encoded images.
+            as a string, numpy array (BGR), base64 encoded images or pre-calculated embeddings.
 
         model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
             OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
@@ -54,6 +58,9 @@ def verify(
         normalization (string): Normalize the input image before feeding it to the model.
             Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base)
 
+        silent (boolean): Suppress or allow some log messages for a quieter analysis process
+            (default is False).
+
     Returns:
         result (dict): A dictionary containing verification results.
 
@@ -81,83 +88,96 @@ def verify(
 
     tic = time.time()
 
-    # --------------------------------
     model: FacialRecognition = modeling.build_model(model_name)
-    target_size = model.input_shape
+    dims = model.output_shape
 
-    try:
-        img1_objs = detection.extract_faces(
-            img_path=img1_path,
-            target_size=target_size,
-            detector_backend=detector_backend,
-            grayscale=False,
-            enforce_detection=enforce_detection,
-            align=align,
-            expand_percentage=expand_percentage,
-        )
-    except ValueError as err:
-        raise ValueError("Exception while processing img1_path") from err
+    if isinstance(img1_path, list):
+        # given image is already pre-calculated embedding
+        if not all(isinstance(dim, float) for dim in img1_path):
+            raise ValueError(
+                "When passing img1_path as a list, ensure that all its items are of type float."
+            )
+
+        if silent is False:
+            logger.warn(
+                "You passed 1st image as pre-calculated embeddings."
+                f"Please ensure that embeddings have been calculated for the {model_name} model."
+            )
+
+        if len(img1_path) != dims:
+            raise ValueError(
+                f"embeddings of {model_name} should have {dims} dimensions,"
+                f" but it has {len(img1_path)} dimensions input"
+            )
+
+        img1_embeddings = [img1_path]
+        img1_facial_areas = [None]
+    else:
+        img1_embeddings, img1_facial_areas = __extract_faces_and_embeddings(
+            img_path=img1_path,
+            model_name=model_name,
+            detector_backend=detector_backend,
+            enforce_detection=enforce_detection,
+            align=align,
+            expand_percentage=expand_percentage,
+            normalization=normalization,
+        )
 
-    try:
-        img2_objs = detection.extract_faces(
-            img_path=img2_path,
-            target_size=target_size,
-            detector_backend=detector_backend,
-            grayscale=False,
-            enforce_detection=enforce_detection,
-            align=align,
-            expand_percentage=expand_percentage,
-        )
-    except ValueError as err:
-        raise ValueError("Exception while processing img2_path") from err
-
-    img1_embeddings = []
-    for img1_obj in img1_objs:
-        img1_embedding_obj = representation.represent(
-            img_path=img1_obj["face"],
-            model_name=model_name,
-            enforce_detection=enforce_detection,
-            detector_backend="skip",
-            align=align,
-            normalization=normalization,
-        )
-        img1_embedding = img1_embedding_obj[0]["embedding"]
-        img1_embeddings.append(img1_embedding)
-
-    img2_embeddings = []
-    for img2_obj in img2_objs:
-        img2_embedding_obj = representation.represent(
-            img_path=img2_obj["face"],
-            model_name=model_name,
-            enforce_detection=enforce_detection,
-            detector_backend="skip",
-            align=align,
-            normalization=normalization,
-        )
-        img2_embedding = img2_embedding_obj[0]["embedding"]
-        img2_embeddings.append(img2_embedding)
+    if isinstance(img2_path, list):
+        # given image is already pre-calculated embedding
+        if not all(isinstance(dim, float) for dim in img2_path):
+            raise ValueError(
+                "When passing img2_path as a list, ensure that all its items are of type float."
+            )
+
+        if silent is False:
+            logger.warn(
+                "You passed 2nd image as pre-calculated embeddings."
+                f"Please ensure that embeddings have been calculated for the {model_name} model."
+            )
+
+        if len(img2_path) != dims:
+            raise ValueError(
+                f"embeddings of {model_name} should have {dims} dimensions,"
+                f" but it has {len(img2_path)} dimensions input"
+            )
+
+        img2_embeddings = [img2_path]
+        img2_facial_areas = [None]
+    else:
+        img2_embeddings, img2_facial_areas = __extract_faces_and_embeddings(
+            img_path=img2_path,
+            model_name=model_name,
+            detector_backend=detector_backend,
+            enforce_detection=enforce_detection,
+            align=align,
+            expand_percentage=expand_percentage,
+            normalization=normalization,
+        )
+
+    no_facial_area = {
+        "x": None,
+        "y": None,
+        "w": None,
+        "h": None,
+        "left_eye": None,
+        "right_eye": None,
+    }
 
     distances = []
-    regions = []
+    facial_areas = []
     for idx, img1_embedding in enumerate(img1_embeddings):
         for idy, img2_embedding in enumerate(img2_embeddings):
-            if distance_metric == "cosine":
-                distance = find_cosine_distance(img1_embedding, img2_embedding)
-            elif distance_metric == "euclidean":
-                distance = find_euclidean_distance(img1_embedding, img2_embedding)
-            elif distance_metric == "euclidean_l2":
-                distance = find_euclidean_distance(
-                    l2_normalize(img1_embedding), l2_normalize(img2_embedding)
-                )
-            else:
-                raise ValueError("Invalid distance_metric passed - ", distance_metric)
+            distance = find_distance(img1_embedding, img2_embedding, distance_metric)
             distances.append(distance)
-            regions.append((img1_objs[idx]["facial_area"], img2_objs[idy]["facial_area"]))
+            facial_areas.append(
+                (img1_facial_areas[idx] or no_facial_area, img2_facial_areas[idy] or no_facial_area)
+            )
 
     # find the face pair with minimum distance
     threshold = find_threshold(model_name, distance_metric)
     distance = float(min(distances))  # best distance
-    facial_areas = regions[np.argmin(distances)]
+    facial_areas = facial_areas[np.argmin(distances)]
 
     toc = time.time()
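With this code path, `verify` can consume embedding vectors directly and skip detection and representation entirely for inputs that were embedded earlier. A usage sketch built from the public API exercised by the new tests below (image paths are illustrative):

```python
from deepface import DeepFace

model_name = "Facenet"

# embed once, e.g. at enrollment time
emb1 = DeepFace.represent(img_path="img1.jpg", model_name=model_name)[0]["embedding"]
emb2 = DeepFace.represent(img_path="img2.jpg", model_name=model_name)[0]["embedding"]

# verify later straight from the stored vectors; the model names must match,
# otherwise the dimension check raises a ValueError
result = DeepFace.verify(img1_path=emb1, img2_path=emb2, model_name=model_name, silent=True)
print(result["verified"], result["distance"], result["threshold"])
```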
@@ -175,6 +195,58 @@ def verify(
     return resp_obj
 
 
+def __extract_faces_and_embeddings(
+    img_path: Union[str, np.ndarray],
+    model_name: str = "VGG-Face",
+    detector_backend: str = "opencv",
+    enforce_detection: bool = True,
+    align: bool = True,
+    expand_percentage: int = 0,
+    normalization: str = "base",
+) -> Tuple[List[List[float]], List[dict]]:
+    """
+    Extract facial areas and find corresponding embeddings for given image
+    Returns:
+        embeddings (List[float])
+        facial areas (List[dict])
+    """
+    embeddings = []
+    facial_areas = []
+
+    model: FacialRecognition = modeling.build_model(model_name)
+    target_size = model.input_shape
+
+    try:
+        img_objs = detection.extract_faces(
+            img_path=img_path,
+            target_size=target_size,
+            detector_backend=detector_backend,
+            grayscale=False,
+            enforce_detection=enforce_detection,
+            align=align,
+            expand_percentage=expand_percentage,
+        )
+    except ValueError as err:
+        raise ValueError("Exception while processing img1_path") from err
+
+    # find embeddings for each face
+    for img_obj in img_objs:
+        img_embedding_obj = representation.represent(
+            img_path=img_obj["face"],
+            model_name=model_name,
+            enforce_detection=enforce_detection,
+            detector_backend="skip",
+            align=align,
+            normalization=normalization,
+        )
+        # already extracted face given, safe to access its 1st item
+        img_embedding = img_embedding_obj[0]["embedding"]
+        embeddings.append(img_embedding)
+        facial_areas.append(img_obj["facial_area"])
+
+    return embeddings, facial_areas
+
+
 def find_cosine_distance(
     source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
 ) -> np.float64:
@@ -234,6 +306,32 @@ def l2_normalize(x: Union[np.ndarray, list]) -> np.ndarray:
     return x / np.sqrt(np.sum(np.multiply(x, x)))
 
 
+def find_distance(
+    alpha_embedding: Union[np.ndarray, list],
+    beta_embedding: Union[np.ndarray, list],
+    distance_metric: str,
+) -> np.float64:
+    """
+    Wrapper to find the distance between two vectors according to the given distance metric
+    Args:
+        alpha_embedding (np.ndarray or list): 1st vector
+        beta_embedding (np.ndarray or list): 2nd vector
+        distance_metric (str): cosine, euclidean or euclidean_l2
+    Returns:
+        distance (np.float64): calculated distance
+    """
+    if distance_metric == "cosine":
+        distance = find_cosine_distance(alpha_embedding, beta_embedding)
+    elif distance_metric == "euclidean":
+        distance = find_euclidean_distance(alpha_embedding, beta_embedding)
+    elif distance_metric == "euclidean_l2":
+        distance = find_euclidean_distance(
+            l2_normalize(alpha_embedding), l2_normalize(beta_embedding)
+        )
+    else:
+        raise ValueError("Invalid distance_metric passed - ", distance_metric)
+    return distance
+
+
 def find_threshold(model_name: str, distance_metric: str) -> float:
     """
     Retrieve pre-tuned threshold values for a model and distance metric pair
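Both `find` and `verify` now route through this wrapper instead of duplicating the metric branching at each call site. A small usage sketch with toy vectors (values illustrative):

```python
import numpy as np
from deepface.modules import verification

a = np.array([0.1, 0.3, 0.5])
b = np.array([0.2, 0.1, 0.4])

for metric in ["cosine", "euclidean", "euclidean_l2"]:
    # one call site for every supported metric; an unknown name raises ValueError
    print(metric, verification.find_distance(a, b, metric))
```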
Unit tests:

@@ -1,3 +1,4 @@
+import pytest
 import cv2
 from deepface import DeepFace
 from deepface.commons.logger import Logger
@@ -100,3 +101,53 @@ def test_verify_for_preloaded_image():
     res = DeepFace.verify(img1, img2)
     assert res["verified"] is True
     logger.info("✅ test verify for pre-loaded image done")
+
+
+def test_verify_for_precalculated_embeddings():
+    model_name = "Facenet"
+
+    img1_path = "dataset/img1.jpg"
+    img2_path = "dataset/img2.jpg"
+
+    img1_embedding = DeepFace.represent(img_path=img1_path, model_name=model_name)[0]["embedding"]
+    img2_embedding = DeepFace.represent(img_path=img2_path, model_name=model_name)[0]["embedding"]
+
+    result = DeepFace.verify(
+        img1_path=img1_embedding, img2_path=img2_embedding, model_name=model_name, silent=True
+    )
+
+    assert result["verified"] is True
+    assert result["distance"] < result["threshold"]
+    assert result["model"] == model_name
+
+    logger.info("✅ test verify for pre-calculated embeddings done")
+
+
+def test_verify_with_precalculated_embeddings_for_incorrect_model():
+    # generate embeddings with VGG (default)
+    img1_path = "dataset/img1.jpg"
+    img2_path = "dataset/img2.jpg"
+    img1_embedding = DeepFace.represent(img_path=img1_path)[0]["embedding"]
+    img2_embedding = DeepFace.represent(img_path=img2_path)[0]["embedding"]
+
+    with pytest.raises(
+        ValueError,
+        match="embeddings of Facenet should have 128 dimensions, but it has 4096 dimensions input",
+    ):
+        _ = DeepFace.verify(
+            img1_path=img1_embedding, img2_path=img2_embedding, model_name="Facenet", silent=True
+        )
+
+    logger.info("✅ test verify with pre-calculated embeddings for incorrect model done")
+
+
+def test_verify_for_broken_embeddings():
+    img1_embeddings = ["a", "b", "c"]
+    img2_embeddings = [1, 2, 3]
+
+    with pytest.raises(
+        ValueError,
+        match="When passing img1_path as a list, ensure that all its items are of type float.",
+    ):
+        _ = DeepFace.verify(img1_path=img1_embeddings, img2_path=img2_embeddings)
+    logger.info("✅ test verify for broken embeddings content is done")