mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
fix l2_normalize and tests
This commit is contained in:
parent
ad0cbaf2dc
commit
3bb317ac63
@ -453,7 +453,7 @@ def find_batched(
|
|||||||
anti_spoofing: bool = False,
|
anti_spoofing: bool = False,
|
||||||
) -> List[List[Dict[str, Any]]]:
|
) -> List[List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Perform batched face recognition by comparing source face embeddingswith a set of
|
Perform batched face recognition by comparing source face embeddings with a set of
|
||||||
target embeddings. It calculates pairwise distances between the source and target
|
target embeddings. It calculates pairwise distances between the source and target
|
||||||
embeddings using the specified distance metric.
|
embeddings using the specified distance metric.
|
||||||
The function uses batch processing for efficient computation of distances.
|
The function uses batch processing for efficient computation of distances.
|
||||||
@ -549,6 +549,7 @@ def find_batched(
|
|||||||
align=align,
|
align=align,
|
||||||
normalization=normalization,
|
normalization=normalization,
|
||||||
)
|
)
|
||||||
|
# it is safe to access 0 index because we already fed detected face to represent function
|
||||||
target_representation = target_embedding_obj[0]["embedding"]
|
target_representation = target_embedding_obj[0]["embedding"]
|
||||||
|
|
||||||
target_embeddings.append(target_representation)
|
target_embeddings.append(target_representation)
|
||||||
@ -566,21 +567,6 @@ def find_batched(
|
|||||||
'source_h': np.array([region['h'] for region in source_regions]),
|
'source_h': np.array([region['h'] for region in source_regions]),
|
||||||
}
|
}
|
||||||
|
|
||||||
def l2_normalize(
|
|
||||||
x: np.ndarray, axis: int = 1, epsilon: float = 1e-10
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Normalize input vectors along a specified axis using L2 normalization
|
|
||||||
Args:
|
|
||||||
x (np.ndarray): input array
|
|
||||||
axis (int): axis along which to normalize
|
|
||||||
epsilon (float): small value to avoid division by zero
|
|
||||||
Returns:
|
|
||||||
np.ndarray: L2-normalized array of the same shape as input
|
|
||||||
"""
|
|
||||||
norm = np.linalg.norm(x, axis=axis, keepdims=True)
|
|
||||||
return x / (norm + epsilon)
|
|
||||||
|
|
||||||
def find_cosine_distance_batch(
|
def find_cosine_distance_batch(
|
||||||
embeddings: np.ndarray, target_embeddings: np.ndarray
|
embeddings: np.ndarray, target_embeddings: np.ndarray
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
@ -592,8 +578,8 @@ def find_batched(
|
|||||||
Returns:
|
Returns:
|
||||||
np.ndarray: distance matrix of shape (M, N)
|
np.ndarray: distance matrix of shape (M, N)
|
||||||
"""
|
"""
|
||||||
embeddings_norm = l2_normalize(embeddings, axis=1)
|
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
|
||||||
target_embeddings_norm = l2_normalize(target_embeddings, axis=1)
|
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
|
||||||
cosine_similarities = np.dot(target_embeddings_norm, embeddings_norm.T)
|
cosine_similarities = np.dot(target_embeddings_norm, embeddings_norm.T)
|
||||||
cosine_distances = 1 - cosine_similarities
|
cosine_distances = 1 - cosine_similarities
|
||||||
return cosine_distances
|
return cosine_distances
|
||||||
@ -630,12 +616,12 @@ def find_batched(
|
|||||||
elif distance_metric == "euclidean":
|
elif distance_metric == "euclidean":
|
||||||
distances = find_euclidean_distance_batch(embeddings, target_embeddings)
|
distances = find_euclidean_distance_batch(embeddings, target_embeddings)
|
||||||
elif distance_metric == "euclidean_l2":
|
elif distance_metric == "euclidean_l2":
|
||||||
embeddings_norm = l2_normalize(embeddings, axis=1)
|
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
|
||||||
target_embeddings_norm = l2_normalize(target_embeddings, axis=1)
|
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
|
||||||
distances = find_euclidean_distance_batch(embeddings_norm, target_embeddings_norm)
|
distances = find_euclidean_distance_batch(embeddings_norm, target_embeddings_norm)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
||||||
return distances
|
return np.round(distances, 6)
|
||||||
|
|
||||||
distances = find_distance_batch(embeddings, target_embeddings, distance_metric) # (M, N)
|
distances = find_distance_batch(embeddings, target_embeddings, distance_metric) # (M, N)
|
||||||
distances[:, ~valid_mask] = np.inf
|
distances[:, ~valid_mask] = np.inf
|
||||||
|
@ -304,18 +304,21 @@ def find_euclidean_distance(
|
|||||||
return np.linalg.norm(source_representation - test_representation)
|
return np.linalg.norm(source_representation - test_representation)
|
||||||
|
|
||||||
|
|
||||||
def l2_normalize(x: Union[np.ndarray, list]) -> np.ndarray:
|
def l2_normalize(
|
||||||
|
x: Union[np.ndarray, list], axis: Union[int, None] = None, epsilon: float = 1e-10
|
||||||
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Normalize input vector with l2
|
Normalize input vector with l2
|
||||||
Args:
|
Args:
|
||||||
x (np.ndarray or list): given vector
|
x (np.ndarray or list): given vector
|
||||||
|
axis (int): axis along which to normalize
|
||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray): l2 normalized vector
|
np.ndarray: l2 normalized vector
|
||||||
"""
|
"""
|
||||||
if isinstance(x, list):
|
if isinstance(x, list):
|
||||||
x = np.array(x)
|
x = np.array(x)
|
||||||
norm = np.linalg.norm(x)
|
norm = np.linalg.norm(x, axis=axis, keepdims=True)
|
||||||
return x if norm == 0 else x / norm
|
return x / (norm + epsilon)
|
||||||
|
|
||||||
|
|
||||||
def find_distance(
|
def find_distance(
|
||||||
@ -341,7 +344,7 @@ def find_distance(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
raise ValueError("Invalid distance_metric passed - ", distance_metric)
|
||||||
return distance
|
return np.round(distance, 6)
|
||||||
|
|
||||||
|
|
||||||
def find_threshold(model_name: str, distance_metric: str) -> float:
|
def find_threshold(model_name: str, distance_metric: str) -> float:
|
||||||
|
@ -3,12 +3,10 @@ import os
|
|||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import cv2
|
import cv2
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
# project dependencies
|
# project dependencies
|
||||||
from deepface import DeepFace
|
from deepface import DeepFace
|
||||||
from deepface.modules import verification
|
from deepface.modules import verification
|
||||||
from deepface.commons import image_utils
|
|
||||||
from deepface.commons.logger import Logger
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
logger = Logger()
|
logger = Logger()
|
||||||
@ -21,12 +19,18 @@ def test_find_with_exact_path():
|
|||||||
img_path = os.path.join("dataset", "img1.jpg")
|
img_path = os.path.join("dataset", "img1.jpg")
|
||||||
results = DeepFace.find(img_path=img_path, db_path="dataset", silent=True, batched=True)
|
results = DeepFace.find(img_path=img_path, db_path="dataset", silent=True, batched=True)
|
||||||
assert len(results) > 0
|
assert len(results) > 0
|
||||||
|
required_keys = set([
|
||||||
|
"identity", "distance", "threshold", "hash",
|
||||||
|
"target_x", "target_y", "target_w", "target_h",
|
||||||
|
"source_x", "source_y", "source_w", "source_h"
|
||||||
|
])
|
||||||
for result in results:
|
for result in results:
|
||||||
assert isinstance(result, list)
|
assert isinstance(result, list)
|
||||||
|
|
||||||
found_image_itself = False
|
found_image_itself = False
|
||||||
for face in result:
|
for face in result:
|
||||||
assert isinstance(face, dict)
|
assert isinstance(face, dict)
|
||||||
|
assert set(face.keys()) == required_keys
|
||||||
if face["identity"] == img_path:
|
if face["identity"] == img_path:
|
||||||
# validate reproducability
|
# validate reproducability
|
||||||
assert face["distance"] < threshold
|
assert face["distance"] < threshold
|
||||||
|
Loading…
x
Reference in New Issue
Block a user