fix l2_normalize and tests

This commit is contained in:
kremnik 2024-09-30 15:02:58 +03:00
parent ad0cbaf2dc
commit 3bb317ac63
3 changed files with 21 additions and 28 deletions

View File

@ -453,7 +453,7 @@ def find_batched(
anti_spoofing: bool = False,
) -> List[List[Dict[str, Any]]]:
"""
Perform batched face recognition by comparing source face embeddingswith a set of
Perform batched face recognition by comparing source face embeddings with a set of
target embeddings. It calculates pairwise distances between the source and target
embeddings using the specified distance metric.
The function uses batch processing for efficient computation of distances.
@ -549,6 +549,7 @@ def find_batched(
align=align,
normalization=normalization,
)
# it is safe to access 0 index because we already fed detected face to represent function
target_representation = target_embedding_obj[0]["embedding"]
target_embeddings.append(target_representation)
@ -566,21 +567,6 @@ def find_batched(
'source_h': np.array([region['h'] for region in source_regions]),
}
def l2_normalize(
    x: np.ndarray, axis: int = 1, epsilon: float = 1e-10
) -> np.ndarray:
    """
    Normalize input vectors along a specified axis using L2 normalization
    Args:
        x (np.ndarray): input array; array-likes such as nested lists are
            coerced with np.asarray so the helper accepts the same inputs
            as verification.l2_normalize
        axis (int): axis along which to normalize
        epsilon (float): small value added to the norm to avoid division by zero
    Returns:
        np.ndarray: L2-normalized array of the same shape as input
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis so the division broadcasts
    # back over the original shape.
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / (norm + epsilon)
def find_cosine_distance_batch(
embeddings: np.ndarray, target_embeddings: np.ndarray
) -> np.ndarray:
@ -592,8 +578,8 @@ def find_batched(
Returns:
np.ndarray: distance matrix of shape (M, N)
"""
embeddings_norm = l2_normalize(embeddings, axis=1)
target_embeddings_norm = l2_normalize(target_embeddings, axis=1)
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
cosine_similarities = np.dot(target_embeddings_norm, embeddings_norm.T)
cosine_distances = 1 - cosine_similarities
return cosine_distances
@ -630,12 +616,12 @@ def find_batched(
elif distance_metric == "euclidean":
distances = find_euclidean_distance_batch(embeddings, target_embeddings)
elif distance_metric == "euclidean_l2":
embeddings_norm = l2_normalize(embeddings, axis=1)
target_embeddings_norm = l2_normalize(target_embeddings, axis=1)
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
distances = find_euclidean_distance_batch(embeddings_norm, target_embeddings_norm)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
return distances
return np.round(distances, 6)
distances = find_distance_batch(embeddings, target_embeddings, distance_metric) # (M, N)
distances[:, ~valid_mask] = np.inf

View File

@ -304,18 +304,21 @@ def find_euclidean_distance(
return np.linalg.norm(source_representation - test_representation)
def l2_normalize(
    x: Union[np.ndarray, list], axis: Union[int, None] = None, epsilon: float = 1e-10
) -> np.ndarray:
    """
    Normalize input vector(s) with l2
    Args:
        x (np.ndarray or list): given vector, or a batch of vectors
        axis (int or None): axis along which to normalize; with None (default)
            a single norm is computed over the whole input
        epsilon (float): small value added to the norm to avoid division by zero
    Returns:
        np.ndarray: l2 normalized vector
    """
    # asarray converts lists and leaves existing ndarrays untouched (no copy).
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis so the division broadcasts
    # correctly for both single vectors and batches.
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / (norm + epsilon)
def find_distance(
@ -341,7 +344,7 @@ def find_distance(
)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
return distance
return np.round(distance, 6)
def find_threshold(model_name: str, distance_metric: str) -> float:

View File

@ -3,12 +3,10 @@ import os
# 3rd party dependencies
import cv2
import pandas as pd
# project dependencies
from deepface import DeepFace
from deepface.modules import verification
from deepface.commons import image_utils
from deepface.commons.logger import Logger
logger = Logger()
@ -21,12 +19,18 @@ def test_find_with_exact_path():
img_path = os.path.join("dataset", "img1.jpg")
results = DeepFace.find(img_path=img_path, db_path="dataset", silent=True, batched=True)
assert len(results) > 0
required_keys = set([
"identity", "distance", "threshold", "hash",
"target_x", "target_y", "target_w", "target_h",
"source_x", "source_y", "source_w", "source_h"
])
for result in results:
assert isinstance(result, list)
found_image_itself = False
for face in result:
assert isinstance(face, dict)
assert set(face.keys()) == required_keys
if face["identity"] == img_path:
# validate reproducibility
assert face["distance"] < threshold