Add batched version of the find function

This commit is contained in:
kremnik 2024-09-30 12:33:19 +03:00
parent 937513453e
commit ad0cbaf2dc
2 changed files with 287 additions and 20 deletions

View File

@ -276,7 +276,8 @@ def find(
silent: bool = False,
refresh_database: bool = True,
anti_spoofing: bool = False,
) -> List[pd.DataFrame]:
batched: bool = False,
) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]:
"""
Identify individuals in a database
Args:
@ -322,22 +323,32 @@ def find(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds
to the identity information for an individual detected in the source image.
The DataFrame columns include:
results (List[pd.DataFrame] or List[List[Dict[str, Any]]]):
A list of pandas dataframes (if `batched=False`) or
a list of dicts (if `batched=True`).
Each dataframe or dict corresponds to the identity information for
an individual detected in the source image.
- 'identity': Identity label of the detected individual.
Note: If you have a large database and/or a source photo with many faces,
use `batched=True`, as it is optimized for large batch processing.
Please pay attention that when using `batched=True`, the function returns
a list of dicts (not a list of DataFrames),
but with the same keys as the columns in the DataFrame.
The DataFrame columns or dict keys include:
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
target face in the database.
- 'identity': Identity label of the detected individual.
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
detected face in the source image.
- 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the
target face in the database.
- 'threshold': threshold to determine a pair whether same person or different persons
- 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the
detected face in the source image.
- 'distance': Similarity score between the faces based on the
specified model and distance metric
- 'threshold': threshold to determine a pair whether same person or different persons
- 'distance': Similarity score between the faces based on the
specified model and distance metric
"""
return recognition.find(
img_path=img_path,
@ -353,6 +364,7 @@ def find(
silent=silent,
refresh_database=refresh_database,
anti_spoofing=anti_spoofing,
batched=batched
)

View File

@ -31,7 +31,8 @@ def find(
silent: bool = False,
refresh_database: bool = True,
anti_spoofing: bool = False,
) -> List[pd.DataFrame]:
batched: bool = False,
) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]:
"""
Identify individuals in a database
@ -77,9 +78,19 @@ def find(
Returns:
results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds
to the identity information for an individual detected in the source image.
The DataFrame columns include:
results (List[pd.DataFrame] or List[List[Dict[str, Any]]]):
A list of pandas dataframes (if `batched=False`) or
a list of dicts (if `batched=True`).
Each dataframe or dict corresponds to the identity information for
an individual detected in the source image.
Note: If you have a large database and/or a source photo with many faces,
use `batched=True`, as it is optimized for large batch processing.
Please pay attention that when using `batched=True`, the function returns
a list of dicts (not a list of DataFrames),
but with the same keys as the columns in the DataFrame.
The DataFrame columns or dict keys include:
- 'identity': Identity label of the detected individual.
@ -233,10 +244,6 @@ def find(
# ----------------------------
# now, we got representations for facial database
df = pd.DataFrame(representations)
if silent is False:
logger.info(f"Searching {img_path} in {df.shape[0]} length datastore")
# img path might have more than once face
source_objs = detection.extract_faces(
@ -249,6 +256,24 @@ def find(
anti_spoofing=anti_spoofing,
)
if batched:
return find_batched(
representations,
source_objs,
model_name,
distance_metric,
enforce_detection,
align,
threshold,
normalization,
anti_spoofing
)
df = pd.DataFrame(representations)
if silent is False:
logger.info(f"Searching {img_path} in {df.shape[0]} length datastore")
resp_obj = []
for source_obj in source_objs:
@ -415,3 +440,233 @@ def __find_bulk_embeddings(
)
return representations
def find_batched(
representations: List[Dict[str, Any]],
source_objs: List[Dict[str, Any]],
model_name: str = "VGG-Face",
distance_metric: str = "cosine",
enforce_detection: bool = True,
align: bool = True,
threshold: Optional[float] = None,
normalization: str = "base",
anti_spoofing: bool = False,
) -> List[List[Dict[str, Any]]]:
"""
Perform batched face recognition by comparing source face embeddingswith a set of
target embeddings. It calculates pairwise distances between the source and target
embeddings using the specified distance metric.
The function uses batch processing for efficient computation of distances.
Args:
representations (List[Dict[str, Any]]):
A list of dictionaries containing precomputed target embeddings and associated metadata.
Each dictionary should have at least the key `embedding`.
source_objs (List[Dict[str, Any]]):
A list of dictionaries representing the source images to compare against
the target embeddings. Each dictionary should contain:
- `face`: The image data or path to the source face image.
- `facial_area`: A dictionary with keys `x`, `y`, `w`, `h`
indicating the facial region.
- Optionally, `is_real`: A boolean indicating if the face is real
(used for anti-spoofing).
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
distance_metric (string): Metric for measuring similarity. Options: 'cosine',
'euclidean', 'euclidean_l2'.
enforce_detection (boolean): If no face is detected in an image, raise an exception.
Default is True. Set to False to avoid the exception for low-resolution images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'.
align (boolean): Perform alignment based on the eye positions.
threshold (float): Specify a threshold to determine whether a pair represents the same
person or different individuals. This threshold is used for comparing distances.
If left unset, default pre-tuned threshold values will be applied based on the specified
model name and distance metric (default is None).
normalization (string): Normalize the input image before feeding it to the model.
Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace
silent (boolean): Suppress or allow some log messages for a quieter analysis process.
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
List[List[Dict[str, Any]]]:
A list where each element corresponds to a source face and
contains a list of dictionaries with matching faces.
"""
embeddings_list = []
valid_mask = []
other_keys = set()
for item in representations:
emb = item.get('embedding')
if emb is not None:
embeddings_list.append(emb)
valid_mask.append(True)
else:
embeddings_list.append(np.zeros_like(representations[0]['embedding']))
valid_mask.append(False)
other_keys.update(item.keys())
# remove embedding key from other keys
other_keys.discard('embedding')
other_keys = list(other_keys)
embeddings = np.array(embeddings_list) # (N, D)
valid_mask = np.array(valid_mask) # (N,)
data = {
key: np.array([item.get(key, None) for item in representations])
for key in other_keys
}
target_embeddings = []
source_regions = []
target_thresholds = []
for source_obj in source_objs:
if anti_spoofing and not source_obj.get("is_real", True):
raise ValueError("Spoof detected in the given image.")
source_img = source_obj["face"]
source_region = source_obj["facial_area"]
target_embedding_obj = representation.represent(
img_path=source_img,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
target_representation = target_embedding_obj[0]["embedding"]
target_embeddings.append(target_representation)
source_regions.append(source_region)
target_threshold = threshold or verification.find_threshold(model_name, distance_metric)
target_thresholds.append(target_threshold)
target_embeddings = np.array(target_embeddings) # (M, D)
target_thresholds = np.array(target_thresholds) # (M,)
source_regions_arr = {
'source_x': np.array([region['x'] for region in source_regions]),
'source_y': np.array([region['y'] for region in source_regions]),
'source_w': np.array([region['w'] for region in source_regions]),
'source_h': np.array([region['h'] for region in source_regions]),
}
def l2_normalize(
x: np.ndarray, axis: int = 1, epsilon: float = 1e-10
) -> np.ndarray:
"""
Normalize input vectors along a specified axis using L2 normalization
Args:
x (np.ndarray): input array
axis (int): axis along which to normalize
epsilon (float): small value to avoid division by zero
Returns:
np.ndarray: L2-normalized array of the same shape as input
"""
norm = np.linalg.norm(x, axis=axis, keepdims=True)
return x / (norm + epsilon)
def find_cosine_distance_batch(
embeddings: np.ndarray, target_embeddings: np.ndarray
) -> np.ndarray:
"""
Find the cosine distances between batches of embeddings
Args:
embeddings (np.ndarray): array of shape (N, D)
target_embeddings (np.ndarray): array of shape (M, D)
Returns:
np.ndarray: distance matrix of shape (M, N)
"""
embeddings_norm = l2_normalize(embeddings, axis=1)
target_embeddings_norm = l2_normalize(target_embeddings, axis=1)
cosine_similarities = np.dot(target_embeddings_norm, embeddings_norm.T)
cosine_distances = 1 - cosine_similarities
return cosine_distances
def find_euclidean_distance_batch(
embeddings: np.ndarray, target_embeddings: np.ndarray
) -> np.ndarray:
"""
Find the Euclidean distances between batches of embeddings
Args:
embeddings (np.ndarray): array of shape (N, D)
target_embeddings (np.ndarray): array of shape (M, D)
Returns:
np.ndarray: distance matrix of shape (M, N)
"""
diff = embeddings[None, :, :] - target_embeddings[:, None, :] # (M, N, D)
distances = np.linalg.norm(diff, axis=2) # (M, N)
return distances
def find_distance_batch(
embeddings: np.ndarray, target_embeddings: np.ndarray, distance_metric: str,
) -> np.ndarray:
"""
Find pairwise distances between batches of embeddings using the specified distance metric
Args:
embeddings (np.ndarray): array of shape (N, D)
target_embeddings (np.ndarray): array of shape (M, D)
distance_metric (str): distance metric ('cosine', 'euclidean', 'euclidean_l2')
Returns:
np.ndarray: distance matrix of shape (M, N)
"""
if distance_metric == "cosine":
distances = find_cosine_distance_batch(embeddings, target_embeddings)
elif distance_metric == "euclidean":
distances = find_euclidean_distance_batch(embeddings, target_embeddings)
elif distance_metric == "euclidean_l2":
embeddings_norm = l2_normalize(embeddings, axis=1)
target_embeddings_norm = l2_normalize(target_embeddings, axis=1)
distances = find_euclidean_distance_batch(embeddings_norm, target_embeddings_norm)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
return distances
distances = find_distance_batch(embeddings, target_embeddings, distance_metric) # (M, N)
distances[:, ~valid_mask] = np.inf
resp_obj = []
for i in range(len(target_embeddings)):
target_distances = distances[i] # (N,)
target_threshold = target_thresholds[i]
N = embeddings.shape[0]
result_data = dict(data)
result_data.update({
'source_x': np.full(N, source_regions_arr['source_x'][i]),
'source_y': np.full(N, source_regions_arr['source_y'][i]),
'source_w': np.full(N, source_regions_arr['source_w'][i]),
'source_h': np.full(N, source_regions_arr['source_h'][i]),
'threshold': np.full(N, target_threshold),
'distance': target_distances,
})
mask = target_distances <= target_threshold
filtered_data = {key: value[mask] for key, value in result_data.items()}
sorted_indices = np.argsort(filtered_data['distance'])
sorted_data = {key: value[sorted_indices] for key, value in filtered_data.items()}
num_results = len(sorted_data['distance'])
result_dicts = [
{key: sorted_data[key][i] for key in sorted_data}
for i in range(num_results)
]
resp_obj.append(result_dicts)
return resp_obj