mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
Merge branch 'serengil:master' into master
This commit is contained in:
commit
476bfd1619
@ -276,7 +276,8 @@ def find(
|
||||
silent: bool = False,
|
||||
refresh_database: bool = True,
|
||||
anti_spoofing: bool = False,
|
||||
) -> List[pd.DataFrame]:
|
||||
batched: bool = False,
|
||||
) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]:
|
||||
"""
|
||||
Identify individuals in a database
|
||||
Args:
|
||||
@ -322,22 +323,32 @@ def find(
|
||||
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
|
||||
|
||||
Returns:
|
||||
results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds
|
||||
to the identity information for an individual detected in the source image.
|
||||
The DataFrame columns include:
|
||||
results (List[pd.DataFrame] or List[List[Dict[str, Any]]]):
|
||||
A list of pandas dataframes (if `batched=False`) or
|
||||
a list of dicts (if `batched=True`).
|
||||
Each dataframe or dict corresponds to the identity information for
|
||||
an individual detected in the source image.
|
||||
|
||||
- 'identity': Identity label of the detected individual.
|
||||
Note: If you have a large database and/or a source photo with many faces,
|
||||
use `batched=True`, as it is optimized for large batch processing.
|
||||
Please pay attention that when using `batched=True`, the function returns
|
||||
a list of dicts (not a list of DataFrames),
|
||||
but with the same keys as the columns in the DataFrame.
|
||||
|
||||
The DataFrame columns or dict keys include:
|
||||
|
||||
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
|
||||
target face in the database.
|
||||
- 'identity': Identity label of the detected individual.
|
||||
|
||||
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
|
||||
detected face in the source image.
|
||||
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
|
||||
target face in the database.
|
||||
|
||||
- 'threshold': threshold to determine a pair whether same person or different persons
|
||||
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
|
||||
detected face in the source image.
|
||||
|
||||
- 'distance': Similarity score between the faces based on the
|
||||
specified model and distance metric
|
||||
- 'threshold': threshold to determine a pair whether same person or different persons
|
||||
|
||||
- 'distance': Similarity score between the faces based on the
|
||||
specified model and distance metric
|
||||
"""
|
||||
return recognition.find(
|
||||
img_path=img_path,
|
||||
@ -353,6 +364,7 @@ def find(
|
||||
silent=silent,
|
||||
refresh_database=refresh_database,
|
||||
anti_spoofing=anti_spoofing,
|
||||
batched=batched
|
||||
)
|
||||
|
||||
|
||||
|
@ -36,7 +36,7 @@ def download_weights_if_necessary(
|
||||
"""
|
||||
home = folder_utils.get_deepface_home()
|
||||
|
||||
target_file = os.path.join(home, ".deepface/weights", file_name)
|
||||
target_file = os.path.normpath(os.path.join(home, ".deepface/weights", file_name))
|
||||
|
||||
if os.path.isfile(target_file):
|
||||
logger.debug(f"{file_name} is already available at {target_file}")
|
||||
|
@ -31,7 +31,8 @@ def find(
|
||||
silent: bool = False,
|
||||
refresh_database: bool = True,
|
||||
anti_spoofing: bool = False,
|
||||
) -> List[pd.DataFrame]:
|
||||
batched: bool = False,
|
||||
) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]:
|
||||
"""
|
||||
Identify individuals in a database
|
||||
|
||||
@ -77,9 +78,19 @@ def find(
|
||||
|
||||
|
||||
Returns:
|
||||
results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds
|
||||
to the identity information for an individual detected in the source image.
|
||||
The DataFrame columns include:
|
||||
results (List[pd.DataFrame] or List[List[Dict[str, Any]]]):
|
||||
A list of pandas dataframes (if `batched=False`) or
|
||||
a list of dicts (if `batched=True`).
|
||||
Each dataframe or dict corresponds to the identity information for
|
||||
an individual detected in the source image.
|
||||
|
||||
Note: If you have a large database and/or a source photo with many faces,
|
||||
use `batched=True`, as it is optimized for large batch processing.
|
||||
Please pay attention that when using `batched=True`, the function returns
|
||||
a list of dicts (not a list of DataFrames),
|
||||
but with the same keys as the columns in the DataFrame.
|
||||
|
||||
The DataFrame columns or dict keys include:
|
||||
|
||||
- 'identity': Identity label of the detected individual.
|
||||
|
||||
@ -233,10 +244,6 @@ def find(
|
||||
|
||||
# ----------------------------
|
||||
# now, we got representations for facial database
|
||||
df = pd.DataFrame(representations)
|
||||
|
||||
if silent is False:
|
||||
logger.info(f"Searching {img_path} in {df.shape[0]} length datastore")
|
||||
|
||||
# img path might have more than once face
|
||||
source_objs = detection.extract_faces(
|
||||
@ -249,6 +256,24 @@ def find(
|
||||
anti_spoofing=anti_spoofing,
|
||||
)
|
||||
|
||||
if batched:
|
||||
return find_batched(
|
||||
representations,
|
||||
source_objs,
|
||||
model_name,
|
||||
distance_metric,
|
||||
enforce_detection,
|
||||
align,
|
||||
threshold,
|
||||
normalization,
|
||||
anti_spoofing
|
||||
)
|
||||
|
||||
df = pd.DataFrame(representations)
|
||||
|
||||
if silent is False:
|
||||
logger.info(f"Searching {img_path} in {df.shape[0]} length datastore")
|
||||
|
||||
resp_obj = []
|
||||
|
||||
for source_obj in source_objs:
|
||||
@ -415,3 +440,219 @@ def __find_bulk_embeddings(
|
||||
)
|
||||
|
||||
return representations
|
||||
|
||||
def find_batched(
|
||||
representations: List[Dict[str, Any]],
|
||||
source_objs: List[Dict[str, Any]],
|
||||
model_name: str = "VGG-Face",
|
||||
distance_metric: str = "cosine",
|
||||
enforce_detection: bool = True,
|
||||
align: bool = True,
|
||||
threshold: Optional[float] = None,
|
||||
normalization: str = "base",
|
||||
anti_spoofing: bool = False,
|
||||
) -> List[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Perform batched face recognition by comparing source face embeddings with a set of
|
||||
target embeddings. It calculates pairwise distances between the source and target
|
||||
embeddings using the specified distance metric.
|
||||
The function uses batch processing for efficient computation of distances.
|
||||
|
||||
Args:
|
||||
representations (List[Dict[str, Any]]):
|
||||
A list of dictionaries containing precomputed target embeddings and associated metadata.
|
||||
Each dictionary should have at least the key `embedding`.
|
||||
|
||||
source_objs (List[Dict[str, Any]]):
|
||||
A list of dictionaries representing the source images to compare against
|
||||
the target embeddings. Each dictionary should contain:
|
||||
- `face`: The image data or path to the source face image.
|
||||
- `facial_area`: A dictionary with keys `x`, `y`, `w`, `h`
|
||||
indicating the facial region.
|
||||
- Optionally, `is_real`: A boolean indicating if the face is real
|
||||
(used for anti-spoofing).
|
||||
|
||||
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
|
||||
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
|
||||
|
||||
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
|
||||
'euclidean', 'euclidean_l2'.
|
||||
|
||||
enforce_detection (boolean): If no face is detected in an image, raise an exception.
|
||||
Default is True. Set to False to avoid the exception for low-resolution images.
|
||||
|
||||
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
|
||||
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
|
||||
|
||||
align (boolean): Perform alignment based on the eye positions.
|
||||
|
||||
threshold (float): Specify a threshold to determine whether a pair represents the same
|
||||
person or different individuals. This threshold is used for comparing distances.
|
||||
If left unset, default pre-tuned threshold values will be applied based on the specified
|
||||
model name and distance metric (default is None).
|
||||
|
||||
normalization (string): Normalize the input image before feeding it to the model.
|
||||
Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace
|
||||
|
||||
silent (boolean): Suppress or allow some log messages for a quieter analysis process.
|
||||
|
||||
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
|
||||
|
||||
Returns:
|
||||
List[List[Dict[str, Any]]]:
|
||||
A list where each element corresponds to a source face and
|
||||
contains a list of dictionaries with matching faces.
|
||||
"""
|
||||
embeddings_list = []
|
||||
valid_mask = []
|
||||
metadata = set()
|
||||
|
||||
for item in representations:
|
||||
emb = item.get('embedding')
|
||||
if emb is not None:
|
||||
embeddings_list.append(emb)
|
||||
valid_mask.append(True)
|
||||
else:
|
||||
embeddings_list.append(np.zeros_like(representations[0]['embedding']))
|
||||
valid_mask.append(False)
|
||||
|
||||
metadata.update(item.keys())
|
||||
|
||||
# remove embedding key from other keys
|
||||
metadata.discard('embedding')
|
||||
metadata = list(metadata)
|
||||
|
||||
embeddings = np.array(embeddings_list) # (N, D)
|
||||
valid_mask = np.array(valid_mask) # (N,)
|
||||
|
||||
data = {
|
||||
key: np.array([item.get(key, None) for item in representations])
|
||||
for key in metadata
|
||||
}
|
||||
|
||||
target_embeddings = []
|
||||
source_regions = []
|
||||
target_thresholds = []
|
||||
|
||||
for source_obj in source_objs:
|
||||
if anti_spoofing and not source_obj.get("is_real", True):
|
||||
raise ValueError("Spoof detected in the given image.")
|
||||
|
||||
source_img = source_obj["face"]
|
||||
source_region = source_obj["facial_area"]
|
||||
|
||||
target_embedding_obj = representation.represent(
|
||||
img_path=source_img,
|
||||
model_name=model_name,
|
||||
enforce_detection=enforce_detection,
|
||||
detector_backend="skip",
|
||||
align=align,
|
||||
normalization=normalization,
|
||||
)
|
||||
# it is safe to access 0 index because we already fed detected face to represent function
|
||||
target_representation = target_embedding_obj[0]["embedding"]
|
||||
|
||||
target_embeddings.append(target_representation)
|
||||
source_regions.append(source_region)
|
||||
|
||||
target_threshold = threshold or verification.find_threshold(model_name, distance_metric)
|
||||
target_thresholds.append(target_threshold)
|
||||
|
||||
target_embeddings = np.array(target_embeddings) # (M, D)
|
||||
target_thresholds = np.array(target_thresholds) # (M,)
|
||||
source_regions_arr = {
|
||||
'source_x': np.array([region['x'] for region in source_regions]),
|
||||
'source_y': np.array([region['y'] for region in source_regions]),
|
||||
'source_w': np.array([region['w'] for region in source_regions]),
|
||||
'source_h': np.array([region['h'] for region in source_regions]),
|
||||
}
|
||||
|
||||
def find_cosine_distance_batch(
|
||||
embeddings: np.ndarray, target_embeddings: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Find the cosine distances between batches of embeddings
|
||||
Args:
|
||||
embeddings (np.ndarray): array of shape (N, D)
|
||||
target_embeddings (np.ndarray): array of shape (M, D)
|
||||
Returns:
|
||||
np.ndarray: distance matrix of shape (M, N)
|
||||
"""
|
||||
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
|
||||
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
|
||||
cosine_similarities = np.dot(target_embeddings_norm, embeddings_norm.T)
|
||||
cosine_distances = 1 - cosine_similarities
|
||||
return cosine_distances
|
||||
|
||||
def find_euclidean_distance_batch(
|
||||
embeddings: np.ndarray, target_embeddings: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Find the Euclidean distances between batches of embeddings
|
||||
Args:
|
||||
embeddings (np.ndarray): array of shape (N, D)
|
||||
target_embeddings (np.ndarray): array of shape (M, D)
|
||||
Returns:
|
||||
np.ndarray: distance matrix of shape (M, N)
|
||||
"""
|
||||
diff = embeddings[None, :, :] - target_embeddings[:, None, :] # (M, N, D)
|
||||
distances = np.linalg.norm(diff, axis=2) # (M, N)
|
||||
return distances
|
||||
|
||||
def find_distance_batch(
|
||||
embeddings: np.ndarray, target_embeddings: np.ndarray, distance_metric: str,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Find pairwise distances between batches of embeddings using the specified distance metric
|
||||
Args:
|
||||
embeddings (np.ndarray): array of shape (N, D)
|
||||
target_embeddings (np.ndarray): array of shape (M, D)
|
||||
distance_metric (str): distance metric ('cosine', 'euclidean', 'euclidean_l2')
|
||||
Returns:
|
||||
np.ndarray: distance matrix of shape (M, N)
|
||||
"""
|
||||
if distance_metric == "cosine":
|
||||
distances = find_cosine_distance_batch(embeddings, target_embeddings)
|
||||
elif distance_metric == "euclidean":
|
||||
distances = find_euclidean_distance_batch(embeddings, target_embeddings)
|
||||
elif distance_metric == "euclidean_l2":
|
||||
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
|
||||
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
|
||||
distances = find_euclidean_distance_batch(embeddings_norm, target_embeddings_norm)
|
||||
else:
|
||||
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
||||
return np.round(distances, 6)
|
||||
|
||||
distances = find_distance_batch(embeddings, target_embeddings, distance_metric) # (M, N)
|
||||
distances[:, ~valid_mask] = np.inf
|
||||
|
||||
resp_obj = []
|
||||
|
||||
for i in range(len(target_embeddings)):
|
||||
target_distances = distances[i] # (N,)
|
||||
target_threshold = target_thresholds[i]
|
||||
|
||||
N = embeddings.shape[0]
|
||||
result_data = dict(data)
|
||||
result_data.update({
|
||||
'source_x': np.full(N, source_regions_arr['source_x'][i]),
|
||||
'source_y': np.full(N, source_regions_arr['source_y'][i]),
|
||||
'source_w': np.full(N, source_regions_arr['source_w'][i]),
|
||||
'source_h': np.full(N, source_regions_arr['source_h'][i]),
|
||||
'threshold': np.full(N, target_threshold),
|
||||
'distance': target_distances,
|
||||
})
|
||||
|
||||
mask = target_distances <= target_threshold
|
||||
filtered_data = {key: value[mask] for key, value in result_data.items()}
|
||||
|
||||
sorted_indices = np.argsort(filtered_data['distance'])
|
||||
sorted_data = {key: value[sorted_indices] for key, value in filtered_data.items()}
|
||||
|
||||
num_results = len(sorted_data['distance'])
|
||||
result_dicts = [
|
||||
{key: sorted_data[key][i] for key in sorted_data}
|
||||
for i in range(num_results)
|
||||
]
|
||||
resp_obj.append(result_dicts)
|
||||
return resp_obj
|
||||
|
@ -304,18 +304,21 @@ def find_euclidean_distance(
|
||||
return np.linalg.norm(source_representation - test_representation)
|
||||
|
||||
|
||||
def l2_normalize(x: Union[np.ndarray, list]) -> np.ndarray:
|
||||
def l2_normalize(
|
||||
x: Union[np.ndarray, list], axis: Union[int, None] = None, epsilon: float = 1e-10
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Normalize input vector with l2
|
||||
Args:
|
||||
x (np.ndarray or list): given vector
|
||||
axis (int): axis along which to normalize
|
||||
Returns:
|
||||
y (np.ndarray): l2 normalized vector
|
||||
np.ndarray: l2 normalized vector
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
x = np.array(x)
|
||||
norm = np.linalg.norm(x)
|
||||
return x if norm == 0 else x / norm
|
||||
norm = np.linalg.norm(x, axis=axis, keepdims=True)
|
||||
return x / (norm + epsilon)
|
||||
|
||||
|
||||
def find_distance(
|
||||
@ -341,7 +344,7 @@ def find_distance(
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
||||
return distance
|
||||
return np.round(distance, 6)
|
||||
|
||||
|
||||
def find_threshold(model_name: str, distance_metric: str) -> float:
|
||||
|
@ -71,14 +71,14 @@ class TestDownloadWeightFeature:
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
mock_isfile.return_value = True
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_get_deepface_home.return_value = os.path.normpath("/mock/home")
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
source_url = "http://example.com/model_weights.zip"
|
||||
|
||||
result = weight_utils.download_weights_if_necessary(file_name, source_url)
|
||||
|
||||
assert result == os.path.join("/mock/home", ".deepface/weights", file_name)
|
||||
assert os.path.normpath(result) == os.path.normpath(os.path.join("/mock/home", ".deepface/weights", file_name))
|
||||
|
||||
mock_gdown.assert_not_called()
|
||||
mock_zipfile.assert_not_called()
|
||||
@ -96,7 +96,7 @@ class TestDownloadWeightFeature:
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_get_deepface_home.return_value = os.path.normpath("/mock/home")
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
@ -125,7 +125,7 @@ class TestDownloadWeightFeature:
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_get_deepface_home.return_value = os.path.normpath("/mock/home")
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
@ -134,13 +134,16 @@ class TestDownloadWeightFeature:
|
||||
# Call the function
|
||||
result = weight_utils.download_weights_if_necessary(file_name, source_url)
|
||||
|
||||
# Normalize the expected path
|
||||
expected_path = os.path.normpath("/mock/home/.deepface/weights/model_weights.h5")
|
||||
|
||||
# Assert that gdown.download was called with the correct parameters
|
||||
mock_gdown.assert_called_once_with(
|
||||
source_url, "/mock/home/.deepface/weights/model_weights.h5", quiet=False
|
||||
source_url, expected_path, quiet=False
|
||||
)
|
||||
|
||||
# Assert that the return value is correct
|
||||
assert result == "/mock/home/.deepface/weights/model_weights.h5"
|
||||
assert result == expected_path
|
||||
|
||||
# Assert that zipfile.ZipFile and bz2.BZ2File were not called
|
||||
mock_zipfile.assert_not_called()
|
||||
@ -159,7 +162,7 @@ class TestDownloadWeightFeature:
|
||||
mock_get_deepface_home,
|
||||
):
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_get_deepface_home.return_value = os.path.normpath("/mock/home")
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
@ -171,7 +174,7 @@ class TestDownloadWeightFeature:
|
||||
|
||||
# Assert that gdown.download was called with the correct parameters
|
||||
mock_gdown.assert_called_once_with(
|
||||
source_url, "/mock/home/.deepface/weights/model_weights.h5.zip", quiet=False
|
||||
source_url, os.path.normpath("/mock/home/.deepface/weights/model_weights.h5.zip"), quiet=False
|
||||
)
|
||||
|
||||
# Simulate the unzipping behavior
|
||||
@ -179,13 +182,13 @@ class TestDownloadWeightFeature:
|
||||
|
||||
# Call the function again to simulate unzipping
|
||||
with mock_zipfile.return_value as zip_ref:
|
||||
zip_ref.extractall("/mock/home/.deepface/weights")
|
||||
zip_ref.extractall(os.path.normpath("/mock/home/.deepface/weights"))
|
||||
|
||||
# Assert that the zip file was unzipped correctly
|
||||
zip_ref.extractall.assert_called_once_with("/mock/home/.deepface/weights")
|
||||
zip_ref.extractall.assert_called_once_with(os.path.normpath("/mock/home/.deepface/weights"))
|
||||
|
||||
# Assert that the return value is correct
|
||||
assert result == "/mock/home/.deepface/weights/model_weights.h5"
|
||||
assert result == os.path.normpath("/mock/home/.deepface/weights/model_weights.h5")
|
||||
|
||||
logger.info("✅ test download weights for zip is done")
|
||||
|
||||
@ -201,7 +204,7 @@ class TestDownloadWeightFeature:
|
||||
):
|
||||
|
||||
# Setting up the mock return values
|
||||
mock_get_deepface_home.return_value = "/mock/home"
|
||||
mock_get_deepface_home.return_value = os.path.normpath("/mock/home")
|
||||
mock_isfile.return_value = False # Simulate file not being present
|
||||
|
||||
file_name = "model_weights.h5"
|
||||
@ -219,16 +222,16 @@ class TestDownloadWeightFeature:
|
||||
|
||||
# Assert that gdown.download was called with the correct parameters
|
||||
mock_gdown.assert_called_once_with(
|
||||
source_url, "/mock/home/.deepface/weights/model_weights.h5.bz2", quiet=False
|
||||
source_url, os.path.normpath("/mock/home/.deepface/weights/model_weights.h5.bz2"), quiet=False
|
||||
)
|
||||
|
||||
# Ensure open() is called once for writing the decompressed data
|
||||
mock_open.assert_called_once_with("/mock/home/.deepface/weights/model_weights.h5", "wb")
|
||||
mock_open.assert_called_once_with(os.path.normpath("/mock/home/.deepface/weights/model_weights.h5"), "wb")
|
||||
|
||||
# TODO: find a way to check write is called
|
||||
|
||||
# Assert that the return value is correct
|
||||
assert result == "/mock/home/.deepface/weights/model_weights.h5"
|
||||
assert result == os.path.normpath("/mock/home/.deepface/weights/model_weights.h5")
|
||||
|
||||
logger.info("✅ test download weights for bz2 is done")
|
||||
|
||||
|
103
tests/test_find_batched.py
Normal file
103
tests/test_find_batched.py
Normal file
@ -0,0 +1,103 @@
|
||||
# built-in dependencies
|
||||
import os
|
||||
|
||||
# 3rd party dependencies
|
||||
import cv2
|
||||
|
||||
# project dependencies
|
||||
from deepface import DeepFace
|
||||
from deepface.modules import verification
|
||||
from deepface.commons.logger import Logger
|
||||
|
||||
logger = Logger()
|
||||
|
||||
|
||||
threshold = verification.find_threshold(model_name="VGG-Face", distance_metric="cosine")
|
||||
|
||||
|
||||
def test_find_with_exact_path():
|
||||
img_path = os.path.join("dataset", "img1.jpg")
|
||||
results = DeepFace.find(img_path=img_path, db_path="dataset", silent=True, batched=True)
|
||||
assert len(results) > 0
|
||||
required_keys = set([
|
||||
"identity", "distance", "threshold", "hash",
|
||||
"target_x", "target_y", "target_w", "target_h",
|
||||
"source_x", "source_y", "source_w", "source_h"
|
||||
])
|
||||
for result in results:
|
||||
assert isinstance(result, list)
|
||||
|
||||
found_image_itself = False
|
||||
for face in result:
|
||||
assert isinstance(face, dict)
|
||||
assert set(face.keys()) == required_keys
|
||||
if face["identity"] == img_path:
|
||||
# validate reproducability
|
||||
assert face["distance"] < threshold
|
||||
# one is img1.jpg itself
|
||||
found_image_itself = True
|
||||
assert found_image_itself
|
||||
|
||||
assert len(results[0]) > 1
|
||||
|
||||
logger.info("✅ test find for exact path done")
|
||||
|
||||
|
||||
def test_find_with_array_input():
|
||||
img_path = os.path.join("dataset", "img1.jpg")
|
||||
img1 = cv2.imread(img_path)
|
||||
results = DeepFace.find(img1, db_path="dataset", silent=True, batched=True)
|
||||
assert len(results) > 0
|
||||
for result in results:
|
||||
assert isinstance(result, list)
|
||||
|
||||
found_image_itself = False
|
||||
for face in result:
|
||||
assert isinstance(face, dict)
|
||||
if face["identity"] == img_path:
|
||||
# validate reproducability
|
||||
assert face["distance"] < threshold
|
||||
# one is img1.jpg itself
|
||||
found_image_itself = True
|
||||
assert found_image_itself
|
||||
|
||||
assert len(results[0]) > 1
|
||||
|
||||
logger.info("✅ test find for array input done")
|
||||
|
||||
|
||||
def test_find_with_extracted_faces():
|
||||
img_path = os.path.join("dataset", "img1.jpg")
|
||||
face_objs = DeepFace.extract_faces(img_path)
|
||||
img = face_objs[0]["face"]
|
||||
results = DeepFace.find(img, db_path="dataset", detector_backend="skip", silent=True, batched=True)
|
||||
assert len(results) > 0
|
||||
for result in results:
|
||||
assert isinstance(result, list)
|
||||
|
||||
found_image_itself = False
|
||||
for face in result:
|
||||
assert isinstance(face, dict)
|
||||
if face["identity"] == img_path:
|
||||
# validate reproducability
|
||||
assert face["distance"] < threshold
|
||||
# one is img1.jpg itself
|
||||
found_image_itself = True
|
||||
assert found_image_itself
|
||||
|
||||
assert len(results[0]) > 1
|
||||
logger.info("✅ test find for extracted face input done")
|
||||
|
||||
|
||||
def test_filetype_for_find():
|
||||
"""
|
||||
only images as jpg and png can be loaded into database
|
||||
"""
|
||||
img_path = os.path.join("dataset", "img1.jpg")
|
||||
results = DeepFace.find(img_path=img_path, db_path="dataset", silent=True, batched=True)
|
||||
|
||||
result = results[0]
|
||||
|
||||
assert not any(face["identity"] == "dataset/img47.jpg" for face in result)
|
||||
|
||||
logger.info("✅ test wrong filetype done")
|
Loading…
x
Reference in New Issue
Block a user