Merge pull request #1072 from serengil/feat-task-0702-find-enhancements

Feat task 0702 find enhancements
Sefik Ilkin Serengil 2024-03-08 20:01:47 +00:00 committed by GitHub
commit 644fc67e9e
7 changed files with 317 additions and 117 deletions


@@ -312,9 +312,11 @@ $ deepface analyze -img_path tests/dataset/img1.jpg
You can also run these commands if you are running deepface with docker. Please follow the instructions in the [shell script](https://github.com/serengil/deepface/blob/master/scripts/dockerize.sh#L17). You can also run these commands if you are running deepface with docker. Please follow the instructions in the [shell script](https://github.com/serengil/deepface/blob/master/scripts/dockerize.sh#L17).
## Contribution [![Tests](https://github.com/serengil/deepface/actions/workflows/tests.yml/badge.svg)](https://github.com/serengil/deepface/actions/workflows/tests.yml) ## Contribution
Pull requests are more than welcome! You should run the unit tests and linting locally by running `make test && make lint` before creating a PR. Once a PR sent, GitHub test workflow will be run automatically and unit test results will be available in [GitHub actions](https://github.com/serengil/deepface/actions) before approval. Besides, workflow will evaluate the code with pylint as well. Pull requests are more than welcome! If you are planning to contribute a large patch, please create an issue first to get any upfront questions or design decisions out of the way.
Before creating a PR, you should run the unit tests and linting locally by running the `make test && make lint` command. Once a PR is sent, the GitHub test workflow will run automatically, and unit test and linting jobs will be available in [GitHub actions](https://github.com/serengil/deepface/actions) before approval.
## Support ## Support


@@ -62,6 +62,7 @@ def verify(
align: bool = True, align: bool = True,
expand_percentage: int = 0, expand_percentage: int = 0,
normalization: str = "base", normalization: str = "base",
silent: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Verify if an image pair represents the same person or different persons. Verify if an image pair represents the same person or different persons.
@@ -91,6 +92,9 @@ def verify(
normalization (string): Normalize the input image before feeding it to the model. normalization (string): Normalize the input image before feeding it to the model.
Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base) Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base)
silent (boolean): Suppress or allow some log messages for a quieter analysis process
(default is False).
Returns: Returns:
result (dict): A dictionary containing verification results with following keys. result (dict): A dictionary containing verification results with following keys.
@@ -126,6 +130,7 @@ def verify(
align=align, align=align,
expand_percentage=expand_percentage, expand_percentage=expand_percentage,
normalization=normalization, normalization=normalization,
silent=silent,
) )
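For illustration, a minimal usage sketch of the new flag (image paths are illustrative); with silent=True the informational log messages are suppressed:

from deepface import DeepFace

result = DeepFace.verify(
    img1_path="tests/dataset/img1.jpg",
    img2_path="tests/dataset/img2.jpg",
    silent=True,  # suppress informational log messages
)
print(result["verified"], result["distance"])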


@@ -1,3 +1,6 @@
# built-in dependencies
import hashlib
# 3rd party dependencies # 3rd party dependencies
import tensorflow as tf import tensorflow as tf
@@ -14,3 +17,16 @@ def get_tf_major_version() -> int:
major_version (int) major_version (int)
""" """
return int(tf.__version__.split(".", maxsplit=1)[0]) return int(tf.__version__.split(".", maxsplit=1)[0])
def find_hash_of_file(file_path: str) -> str:
"""
Find hash of image file
Args:
file_path (str): exact image path
Returns:
hash (str): digest with sha1 algorithm
"""
with open(file_path, "rb") as f:
digest = hashlib.sha1(f.read()).hexdigest()
return digest
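As a quick sketch of why this helper exists (the path is illustrative): the sha1 digest changes whenever the file content changes, which is what lets the find pipeline detect images that were replaced in place:

from deepface.commons import package_utils

digest_before = package_utils.find_hash_of_file("tests/dataset/img1.jpg")
# ... the file is later overwritten with a different photo ...
digest_after = package_utils.find_hash_of_file("tests/dataset/img1.jpg")
must_refresh = digest_before != digest_after  # True when the image was replaced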


@@ -34,7 +34,7 @@ def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
return load_base64(img), "base64 encoded string" return load_base64(img), "base64 encoded string"
# The image is a url # The image is a url
if img.startswith("http://") or img.startswith("https://"): if img.lower().startswith("http://") or img.lower().startswith("https://"):
return load_image_from_web(url=img), img return load_image_from_web(url=img), img
# The image is a path # The image is a path
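The added .lower() makes the scheme check case-insensitive; a small self-contained illustration:

img = "HTTPS://example.com/face.jpg"
print(img.startswith("http://") or img.startswith("https://"))                  # False before this change
print(img.lower().startswith("http://") or img.lower().startswith("https://"))  # True with this change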


@@ -1,7 +1,7 @@
# built-in dependencies # built-in dependencies
import os import os
import pickle import pickle
from typing import List, Union, Optional from typing import List, Union, Optional, Dict, Any
import time import time
# 3rd party dependencies # 3rd party dependencies
@@ -11,6 +11,7 @@ from tqdm import tqdm
# project dependencies # project dependencies
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
from deepface.commons import package_utils
from deepface.modules import representation, detection, modeling, verification from deepface.modules import representation, detection, modeling, verification
from deepface.models.FacialRecognition import FacialRecognition from deepface.models.FacialRecognition import FacialRecognition
@@ -97,14 +98,16 @@ def find(
# --------------------------------------- # ---------------------------------------
file_name = f"representations_{model_name}.pkl" file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
file_name = file_name.replace("-", "_").lower() file_name = file_name.replace("-", "").lower()
datastore_path = os.path.join(db_path, file_name) datastore_path = os.path.join(db_path, file_name)
representations = [] representations = []
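For illustration, the new file name encodes both the model and the detector backend, so datastores built with different settings no longer collide; with the defaults VGG-Face and opencv it resolves as follows:

model_name, detector_backend = "VGG-Face", "opencv"
file_name = f"ds_{model_name}_{detector_backend}_v2.pkl".replace("-", "").lower()
print(file_name)  # ds_vggface_opencv_v2.pkl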
# required columns for representations
df_cols = [ df_cols = [
"identity", "identity",
f"{model_name}_representation", "hash",
"embedding",
"target_x", "target_x",
"target_y", "target_y",
"target_w", "target_w",
@@ -120,35 +123,59 @@ def find(
with open(datastore_path, "rb") as f: with open(datastore_path, "rb") as f:
representations = pickle.load(f) representations = pickle.load(f)
# Check if the representations are out-of-date # check each item of representations list has required keys
if len(representations) > 0: for i, current_representation in enumerate(representations):
if len(representations[0]) != len(df_cols): missing_keys = list(set(df_cols) - set(current_representation.keys()))
if len(missing_keys) > 0:
raise ValueError( raise ValueError(
f"Seems existing {datastore_path} is out-of-the-date." f"{i}-th item does not have some required keys - {missing_keys}."
"Please delete it and re-run." f"Consider to delete {datastore_path}"
) )
pickled_images = [representation[0] for representation in representations]
else: # embedded images
pickled_images = [] pickled_images = [representation["identity"] for representation in representations]
# Get the list of images on storage # Get the list of images on storage
storage_images = __list_images(path=db_path) storage_images = __list_images(path=db_path)
if len(storage_images) == 0:
raise ValueError(f"No item found in {db_path}")
# Enforce data consistency amongst on disk images and pickle file # Enforce data consistency amongst on disk images and pickle file
must_save_pickle = False must_save_pickle = False
new_images = list(set(storage_images) - set(pickled_images)) # images added to storage new_images = list(set(storage_images) - set(pickled_images)) # images added to storage
old_images = list(set(pickled_images) - set(storage_images)) # images removed from storage old_images = list(set(pickled_images) - set(storage_images)) # images removed from storage
if not silent and (len(new_images) > 0 or len(old_images) > 0): # detect replaced images
logger.info(f"Found {len(new_images)} new images and {len(old_images)} removed images") replaced_images = []
for current_representation in representations:
identity = current_representation["identity"]
if identity in old_images:
continue
alpha_hash = current_representation["hash"]
beta_hash = package_utils.find_hash_of_file(identity)
if alpha_hash != beta_hash:
logger.debug(f"Even though {identity} represented before, it's replaced later.")
replaced_images.append(identity)
if not silent and (len(new_images) > 0 or len(old_images) > 0 or len(replaced_images) > 0):
logger.info(
f"Found {len(new_images)} newly added image(s)"
f", {len(old_images)} removed image(s)"
f", {len(replaced_images)} replaced image(s)."
)
# append replaced images into both old and new images. these will be dropped and re-added.
new_images = new_images + replaced_images
old_images = old_images + replaced_images
# remove old images first # remove old images first
if len(old_images)>0: if len(old_images) > 0:
representations = [rep for rep in representations if rep[0] not in old_images] representations = [rep for rep in representations if rep["identity"] not in old_images]
must_save_pickle = True must_save_pickle = True
# find representations for new images # find representations for new images
if len(new_images)>0: if len(new_images) > 0:
representations += __find_bulk_embeddings( representations += __find_bulk_embeddings(
employees=new_images, employees=new_images,
model_name=model_name, model_name=model_name,
@@ -176,10 +203,10 @@ def find(
# ---------------------------- # ----------------------------
# now, we got representations for facial database # now, we got representations for facial database
df = pd.DataFrame( df = pd.DataFrame(representations)
representations,
columns=df_cols, if silent is False:
) logger.info(f"Searching {img_path} in a datastore of {df.shape[0]} entries")
# img path might have more than one face # img path might have more than one face
source_objs = detection.extract_faces( source_objs = detection.extract_faces(
@@ -216,7 +243,7 @@ def find(
distances = [] distances = []
for _, instance in df.iterrows(): for _, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"] source_representation = instance["embedding"]
if source_representation is None: if source_representation is None:
distances.append(float("inf")) # no representation for this image distances.append(float("inf")) # no representation for this image
continue continue
@@ -230,21 +257,9 @@ def find(
+ " after pickle created. Delete the {file_name} and re-run." + " after pickle created. Delete the {file_name} and re-run."
) )
if distance_metric == "cosine": distance = verification.find_distance(
distance = verification.find_cosine_distance( source_representation, target_representation, distance_metric
source_representation, target_representation
) )
elif distance_metric == "euclidean":
distance = verification.find_euclidean_distance(
source_representation, target_representation
)
elif distance_metric == "euclidean_l2":
distance = verification.find_euclidean_distance(
verification.l2_normalize(source_representation),
verification.l2_normalize(target_representation),
)
else:
raise ValueError(f"invalid distance metric passes - {distance_metric}")
distances.append(distance) distances.append(distance)
@@ -254,7 +269,7 @@ def find(
result_df["threshold"] = target_threshold result_df["threshold"] = target_threshold
result_df["distance"] = distances result_df["distance"] = distances
result_df = result_df.drop(columns=[f"{model_name}_representation"]) result_df = result_df.drop(columns=["embedding"])
# pylint: disable=unsubscriptable-object # pylint: disable=unsubscriptable-object
result_df = result_df[result_df["distance"] <= target_threshold] result_df = result_df[result_df["distance"] <= target_threshold]
result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True) result_df = result_df.sort_values(by=["distance"], ascending=True).reset_index(drop=True)
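Taken together, a hedged end-to-end sketch of calling find after these changes (paths are illustrative); it returns one pandas DataFrame per face detected in img_path, already filtered by the threshold and sorted by distance:

from deepface import DeepFace

dfs = DeepFace.find(img_path="tests/dataset/img1.jpg", db_path="tests/dataset")
for df in dfs:
    # the raw "embedding" column is dropped before returning
    print(df[["identity", "distance", "threshold"]].head())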
@@ -297,7 +312,7 @@ def __find_bulk_embeddings(
expand_percentage: int = 0, expand_percentage: int = 0,
normalization: str = "base", normalization: str = "base",
silent: bool = False, silent: bool = False,
): ) -> List[Dict[str, Any]]:
""" """
Find embeddings of a list of images Find embeddings of a list of images
@@ -323,8 +338,8 @@ def __find_bulk_embeddings(
silent (bool): enable or disable informative logging silent (bool): enable or disable informative logging
Returns: Returns:
representations (list): pivot list of embeddings with representations (list): list of dicts with
image name and detected face area's coordinates image name, hash, embedding and detected face area coordinates
""" """
representations = [] representations = []
for employee in tqdm( for employee in tqdm(
@@ -332,6 +347,8 @@ def __find_bulk_embeddings(
desc="Finding representations", desc="Finding representations",
disable=silent, disable=silent,
): ):
file_hash = package_utils.find_hash_of_file(employee)
try: try:
img_objs = detection.extract_faces( img_objs = detection.extract_faces(
img_path=employee, img_path=employee,
@@ -342,15 +359,23 @@ def __find_bulk_embeddings(
align=align, align=align,
expand_percentage=expand_percentage, expand_percentage=expand_percentage,
) )
except ValueError as err: except ValueError as err:
logger.error( logger.error(f"Exception while extracting faces from {employee}: {str(err)}")
f"Exception while extracting faces from {employee}: {str(err)}"
)
img_objs = [] img_objs = []
if len(img_objs) == 0: if len(img_objs) == 0:
logger.warn(f"No face detected in {employee}. It will be skipped in detection.") representations.append(
representations.append((employee, None, 0, 0, 0, 0)) {
"identity": employee,
"hash": file_hash,
"embedding": None,
"target_x": 0,
"target_y": 0,
"target_w": 0,
"target_h": 0,
}
)
else: else:
for img_obj in img_objs: for img_obj in img_objs:
img_content = img_obj["face"] img_content = img_obj["face"]
@@ -365,13 +390,16 @@ def __find_bulk_embeddings(
) )
img_representation = embedding_obj[0]["embedding"] img_representation = embedding_obj[0]["embedding"]
representations.append(( representations.append(
employee, {
img_representation, "identity": employee,
img_region["x"], "hash": file_hash,
img_region["y"], "embedding": img_representation,
img_region["w"], "target_x": img_region["x"],
img_region["h"] "target_y": img_region["y"],
)) "target_w": img_region["w"],
"target_h": img_region["h"],
}
)
return representations return representations
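After this refactor each pickled entry is a self-describing dict rather than a positional tuple; a sample entry for the no-face case, following the keys above (the hash value is illustrative):

sample_entry = {
    "identity": "tests/dataset/img1.jpg",  # path of the source image
    "hash": "0e14d...",                    # sha1 digest used for change detection
    "embedding": None,                     # embedding vector, or None when no face was detected
    "target_x": 0,
    "target_y": 0,
    "target_w": 0,
    "target_h": 0,
}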


@@ -1,6 +1,6 @@
# built-in dependencies # built-in dependencies
import time import time
from typing import Any, Dict, Union from typing import Any, Dict, Union, List, Tuple
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@@ -8,11 +8,14 @@ import numpy as np
# project dependencies # project dependencies
from deepface.modules import representation, detection, modeling from deepface.modules import representation, detection, modeling
from deepface.models.FacialRecognition import FacialRecognition from deepface.models.FacialRecognition import FacialRecognition
from deepface.commons.logger import Logger
logger = Logger(module="deepface/modules/verification.py")
def verify( def verify(
img1_path: Union[str, np.ndarray], img1_path: Union[str, np.ndarray, List[float]],
img2_path: Union[str, np.ndarray], img2_path: Union[str, np.ndarray, List[float]],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
detector_backend: str = "opencv", detector_backend: str = "opencv",
distance_metric: str = "cosine", distance_metric: str = "cosine",
@@ -20,6 +23,7 @@ def verify(
align: bool = True, align: bool = True,
expand_percentage: int = 0, expand_percentage: int = 0,
normalization: str = "base", normalization: str = "base",
silent: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Verify if an image pair represents the same person or different persons. Verify if an image pair represents the same person or different persons.
@@ -30,10 +34,10 @@ def verify(
Args: Args:
img1_path (str or np.ndarray): Path to the first image. Accepts exact image path img1_path (str or np.ndarray): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), base64 encoded images or pre-calculated embeddings.
img2_path (str or np.ndarray): Path to the second image. Accepts exact image path img2_path (str or np.ndarray): Path to the second image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), base64 encoded images or pre-calculated embeddings.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace and SFace (default is VGG-Face).
@@ -54,6 +58,9 @@ def verify(
normalization (string): Normalize the input image before feeding it to the model. normalization (string): Normalize the input image before feeding it to the model.
Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base) Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace (default is base)
silent (boolean): Suppress or allow some log messages for a quieter analysis process
(default is False).
Returns: Returns:
result (dict): A dictionary containing verification results. result (dict): A dictionary containing verification results.
@@ -81,83 +88,96 @@ def verify(
tic = time.time() tic = time.time()
# --------------------------------
model: FacialRecognition = modeling.build_model(model_name) model: FacialRecognition = modeling.build_model(model_name)
target_size = model.input_shape dims = model.output_shape
try: if isinstance(img1_path, list):
img1_objs = detection.extract_faces( # given image is already pre-calculated embedding
if not all(isinstance(dim, float) for dim in img1_path):
raise ValueError(
"When passing img1_path as a list, ensure that all its items are of type float."
)
if silent is False:
logger.warn(
"You passed 1st image as pre-calculated embeddings."
f"Please ensure that embeddings have been calculated for the {model_name} model."
)
if len(img1_path) != dims:
raise ValueError(
f"embeddings of {model_name} should have {dims} dimensions,"
f" but it has {len(img1_path)} dimensions input"
)
img1_embeddings = [img1_path]
img1_facial_areas = [None]
else:
img1_embeddings, img1_facial_areas = __extract_faces_and_embeddings(
img_path=img1_path, img_path=img1_path,
target_size=target_size, model_name=model_name,
detector_backend=detector_backend, detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection, enforce_detection=enforce_detection,
align=align, align=align,
expand_percentage=expand_percentage, expand_percentage=expand_percentage,
normalization=normalization,
) )
except ValueError as err:
raise ValueError("Exception while processing img1_path") from err
try: if isinstance(img2_path, list):
img2_objs = detection.extract_faces( # given image is already pre-calculated embedding
if not all(isinstance(dim, float) for dim in img2_path):
raise ValueError(
"When passing img2_path as a list, ensure that all its items are of type float."
)
if silent is False:
logger.warn(
"You passed 2nd image as pre-calculated embeddings."
f"Please ensure that embeddings have been calculated for the {model_name} model."
)
if len(img2_path) != dims:
raise ValueError(
f"embeddings of {model_name} should have {dims} dimensions,"
f" but it has {len(img2_path)} dimensions input"
)
img2_embeddings = [img2_path]
img2_facial_areas = [None]
else:
img2_embeddings, img2_facial_areas = __extract_faces_and_embeddings(
img_path=img2_path, img_path=img2_path,
target_size=target_size, model_name=model_name,
detector_backend=detector_backend, detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection, enforce_detection=enforce_detection,
align=align, align=align,
expand_percentage=expand_percentage, expand_percentage=expand_percentage,
)
except ValueError as err:
raise ValueError("Exception while processing img2_path") from err
img1_embeddings = []
for img1_obj in img1_objs:
img1_embedding_obj = representation.represent(
img_path=img1_obj["face"],
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization, normalization=normalization,
) )
img1_embedding = img1_embedding_obj[0]["embedding"]
img1_embeddings.append(img1_embedding)
img2_embeddings = [] no_facial_area = {
for img2_obj in img2_objs: "x": None,
img2_embedding_obj = representation.represent( "y": None,
img_path=img2_obj["face"], "w": None,
model_name=model_name, "h": None,
enforce_detection=enforce_detection, "left_eye": None,
detector_backend="skip", "right_eye": None,
align=align, }
normalization=normalization,
)
img2_embedding = img2_embedding_obj[0]["embedding"]
img2_embeddings.append(img2_embedding)
distances = [] distances = []
regions = [] facial_areas = []
for idx, img1_embedding in enumerate(img1_embeddings): for idx, img1_embedding in enumerate(img1_embeddings):
for idy, img2_embedding in enumerate(img2_embeddings): for idy, img2_embedding in enumerate(img2_embeddings):
if distance_metric == "cosine": distance = find_distance(img1_embedding, img2_embedding, distance_metric)
distance = find_cosine_distance(img1_embedding, img2_embedding)
elif distance_metric == "euclidean":
distance = find_euclidean_distance(img1_embedding, img2_embedding)
elif distance_metric == "euclidean_l2":
distance = find_euclidean_distance(
l2_normalize(img1_embedding), l2_normalize(img2_embedding)
)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
distances.append(distance) distances.append(distance)
regions.append((img1_objs[idx]["facial_area"], img2_objs[idy]["facial_area"])) facial_areas.append(
(img1_facial_areas[idx] or no_facial_area, img2_facial_areas[idy] or no_facial_area)
)
# find the face pair with minimum distance # find the face pair with minimum distance
threshold = find_threshold(model_name, distance_metric) threshold = find_threshold(model_name, distance_metric)
distance = float(min(distances)) # best distance distance = float(min(distances)) # best distance
facial_areas = regions[np.argmin(distances)] facial_areas = facial_areas[np.argmin(distances)]
toc = time.time() toc = time.time()
@@ -175,6 +195,58 @@ def verify(
return resp_obj return resp_obj
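A hedged sketch of the new fast path (paths are illustrative): embeddings computed once with represent can be reused across many verify calls, skipping face detection and the recognition model on every comparison. The embeddings must come from the same model that verify is told to use:

from deepface import DeepFace

emb1 = DeepFace.represent(img_path="tests/dataset/img1.jpg", model_name="Facenet")[0]["embedding"]
emb2 = DeepFace.represent(img_path="tests/dataset/img2.jpg", model_name="Facenet")[0]["embedding"]

result = DeepFace.verify(img1_path=emb1, img2_path=emb2, model_name="Facenet", silent=True)
print(result["verified"])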
def __extract_faces_and_embeddings(
img_path: Union[str, np.ndarray],
model_name: str = "VGG-Face",
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
expand_percentage: int = 0,
normalization: str = "base",
) -> Tuple[List[List[float]], List[dict]]:
"""
Extract facial areas and find corresponding embeddings for given image
Returns:
embeddings (List[List[float]])
facial areas (List[dict])
"""
embeddings = []
facial_areas = []
model: FacialRecognition = modeling.build_model(model_name)
target_size = model.input_shape
try:
img_objs = detection.extract_faces(
img_path=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
)
except ValueError as err:
raise ValueError("Exception while processing img1_path") from err
# find embeddings for each face
for img_obj in img_objs:
img_embedding_obj = representation.represent(
img_path=img_obj["face"],
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
# already extracted face given, safe to access its 1st item
img_embedding = img_embedding_obj[0]["embedding"]
embeddings.append(img_embedding)
facial_areas.append(img_obj["facial_area"])
return embeddings, facial_areas
def find_cosine_distance( def find_cosine_distance(
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list] source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
) -> np.float64: ) -> np.float64:
@@ -234,6 +306,32 @@ def l2_normalize(x: Union[np.ndarray, list]) -> np.ndarray:
return x / np.sqrt(np.sum(np.multiply(x, x))) return x / np.sqrt(np.sum(np.multiply(x, x)))
def find_distance(
alpha_embedding: Union[np.ndarray, list],
beta_embedding: Union[np.ndarray, list],
distance_metric: str,
) -> np.float64:
"""
Wrapper to find distance between vectors according to the given distance metric
Args:
alpha_embedding (np.ndarray or list): 1st vector
beta_embedding (np.ndarray or list): 2nd vector
distance_metric (str): distance metric - cosine, euclidean or euclidean_l2
Returns:
distance (np.float64): calculated distance
"""
if distance_metric == "cosine":
distance = find_cosine_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean":
distance = find_euclidean_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean_l2":
distance = find_euclidean_distance(
l2_normalize(alpha_embedding), l2_normalize(beta_embedding)
)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
return distance
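A short usage sketch of the new wrapper; it dispatches to the metric helpers above, so callers no longer need their own if/elif chain:

from deepface.modules import verification
import numpy as np

a = np.array([0.1, 0.2, 0.3])
b = np.array([0.3, 0.1, 0.2])
for metric in ["cosine", "euclidean", "euclidean_l2"]:
    print(metric, verification.find_distance(a, b, metric))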
def find_threshold(model_name: str, distance_metric: str) -> float: def find_threshold(model_name: str, distance_metric: str) -> float:
""" """
Retrieve pre-tuned threshold values for a model and distance metric pair Retrieve pre-tuned threshold values for a model and distance metric pair


@@ -1,3 +1,4 @@
import pytest
import cv2 import cv2
from deepface import DeepFace from deepface import DeepFace
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
@@ -100,3 +101,53 @@ def test_verify_for_preloaded_image():
res = DeepFace.verify(img1, img2) res = DeepFace.verify(img1, img2)
assert res["verified"] is True assert res["verified"] is True
logger.info("✅ test verify for pre-loaded image done") logger.info("✅ test verify for pre-loaded image done")
def test_verify_for_precalculated_embeddings():
model_name = "Facenet"
img1_path = "dataset/img1.jpg"
img2_path = "dataset/img2.jpg"
img1_embedding = DeepFace.represent(img_path=img1_path, model_name=model_name)[0]["embedding"]
img2_embedding = DeepFace.represent(img_path=img2_path, model_name=model_name)[0]["embedding"]
result = DeepFace.verify(
img1_path=img1_embedding, img2_path=img2_embedding, model_name=model_name, silent=True
)
assert result["verified"] is True
assert result["distance"] < result["threshold"]
assert result["model"] == model_name
logger.info("✅ test verify for pre-calculated embeddings done")
def test_verify_with_precalculated_embeddings_for_incorrect_model():
# generate embeddings with VGG (default)
img1_path = "dataset/img1.jpg"
img2_path = "dataset/img2.jpg"
img1_embedding = DeepFace.represent(img_path=img1_path)[0]["embedding"]
img2_embedding = DeepFace.represent(img_path=img2_path)[0]["embedding"]
with pytest.raises(
ValueError,
match="embeddings of Facenet should have 128 dimensions, but it has 4096 dimensions input",
):
_ = DeepFace.verify(
img1_path=img1_embedding, img2_path=img2_embedding, model_name="Facenet", silent=True
)
logger.info("✅ test verify with pre-calculated embeddings for incorrect model done")
def test_verify_for_broken_embeddings():
img1_embeddings = ["a", "b", "c"]
img2_embeddings = [1, 2, 3]
with pytest.raises(
ValueError,
match="When passing img1_path as a list, ensure that all its items are of type float.",
):
_ = DeepFace.verify(img1_path=img1_embeddings, img2_path=img2_embeddings)
logger.info("✅ test verify for broken embeddings content is done")