From ad0cbaf2dce255e488440f0ca9168012a8a4a05a Mon Sep 17 00:00:00 2001 From: kremnik Date: Mon, 30 Sep 2024 12:33:19 +0300 Subject: [PATCH] Add batched version of the find function --- deepface/DeepFace.py | 36 +++-- deepface/modules/recognition.py | 271 +++++++++++++++++++++++++++++++- 2 files changed, 287 insertions(+), 20 deletions(-) diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 5848d7b..7d7e81f 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -276,7 +276,8 @@ def find( silent: bool = False, refresh_database: bool = True, anti_spoofing: bool = False, -) -> List[pd.DataFrame]: + batched: bool = False, +) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]: """ Identify individuals in a database Args: @@ -322,22 +323,32 @@ def find( anti_spoofing (boolean): Flag to enable anti spoofing (default is False). Returns: - results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds - to the identity information for an individual detected in the source image. - The DataFrame columns include: + results (List[pd.DataFrame] or List[List[Dict[str, Any]]]): + A list of pandas dataframes (if `batched=False`) or + a list of dicts (if `batched=True`). + Each dataframe or dict corresponds to the identity information for + an individual detected in the source image. - - 'identity': Identity label of the detected individual. + Note: If you have a large database and/or a source photo with many faces, + use `batched=True`, as it is optimized for large batch processing. + Please pay attention that when using `batched=True`, the function returns + a list of dicts (not a list of DataFrames), + but with the same keys as the columns in the DataFrame. + + The DataFrame columns or dict keys include: - - 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the - target face in the database. + - 'identity': Identity label of the detected individual. - - 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the - detected face in the source image. + - 'target_x', 'target_y', 'target_w', 'target_h': Bounding box coordinates of the + target face in the database. - - 'threshold': threshold to determine a pair whether same person or different persons + - 'source_x', 'source_y', 'source_w', 'source_h': Bounding box coordinates of the + detected face in the source image. - - 'distance': Similarity score between the faces based on the - specified model and distance metric + - 'threshold': threshold to determine a pair whether same person or different persons + + - 'distance': Similarity score between the faces based on the + specified model and distance metric """ return recognition.find( img_path=img_path, @@ -353,6 +364,7 @@ def find( silent=silent, refresh_database=refresh_database, anti_spoofing=anti_spoofing, + batched=batched ) diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index 799dfbc..664b1cc 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -31,7 +31,8 @@ def find( silent: bool = False, refresh_database: bool = True, anti_spoofing: bool = False, -) -> List[pd.DataFrame]: + batched: bool = False, +) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]: """ Identify individuals in a database @@ -77,9 +78,19 @@ def find( Returns: - results (List[pd.DataFrame]): A list of pandas dataframes. Each dataframe corresponds - to the identity information for an individual detected in the source image. - The DataFrame columns include: + results (List[pd.DataFrame] or List[List[Dict[str, Any]]]): + A list of pandas dataframes (if `batched=False`) or + a list of dicts (if `batched=True`). + Each dataframe or dict corresponds to the identity information for + an individual detected in the source image. + + Note: If you have a large database and/or a source photo with many faces, + use `batched=True`, as it is optimized for large batch processing. + Please pay attention that when using `batched=True`, the function returns + a list of dicts (not a list of DataFrames), + but with the same keys as the columns in the DataFrame. + + The DataFrame columns or dict keys include: - 'identity': Identity label of the detected individual. @@ -233,10 +244,6 @@ def find( # ---------------------------- # now, we got representations for facial database - df = pd.DataFrame(representations) - - if silent is False: - logger.info(f"Searching {img_path} in {df.shape[0]} length datastore") # img path might have more than once face source_objs = detection.extract_faces( @@ -249,6 +256,24 @@ def find( anti_spoofing=anti_spoofing, ) + if batched: + return find_batched( + representations, + source_objs, + model_name, + distance_metric, + enforce_detection, + align, + threshold, + normalization, + anti_spoofing + ) + + df = pd.DataFrame(representations) + + if silent is False: + logger.info(f"Searching {img_path} in {df.shape[0]} length datastore") + resp_obj = [] for source_obj in source_objs: @@ -415,3 +440,233 @@ def __find_bulk_embeddings( ) return representations + +def find_batched( + representations: List[Dict[str, Any]], + source_objs: List[Dict[str, Any]], + model_name: str = "VGG-Face", + distance_metric: str = "cosine", + enforce_detection: bool = True, + align: bool = True, + threshold: Optional[float] = None, + normalization: str = "base", + anti_spoofing: bool = False, +) -> List[List[Dict[str, Any]]]: + """ + Perform batched face recognition by comparing source face embeddingswith a set of + target embeddings. It calculates pairwise distances between the source and target + embeddings using the specified distance metric. + The function uses batch processing for efficient computation of distances. + + Args: + representations (List[Dict[str, Any]]): + A list of dictionaries containing precomputed target embeddings and associated metadata. + Each dictionary should have at least the key `embedding`. + + source_objs (List[Dict[str, Any]]): + A list of dictionaries representing the source images to compare against + the target embeddings. Each dictionary should contain: + - `face`: The image data or path to the source face image. + - `facial_area`: A dictionary with keys `x`, `y`, `w`, `h` + indicating the facial region. + - Optionally, `is_real`: A boolean indicating if the face is real + (used for anti-spoofing). + + model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + + distance_metric (string): Metric for measuring similarity. Options: 'cosine', + 'euclidean', 'euclidean_l2'. + + enforce_detection (boolean): If no face is detected in an image, raise an exception. + Default is True. Set to False to avoid the exception for low-resolution images. + + detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'. + + align (boolean): Perform alignment based on the eye positions. + + threshold (float): Specify a threshold to determine whether a pair represents the same + person or different individuals. This threshold is used for comparing distances. + If left unset, default pre-tuned threshold values will be applied based on the specified + model name and distance metric (default is None). + + normalization (string): Normalize the input image before feeding it to the model. + Default is base. Options: base, raw, Facenet, Facenet2018, VGGFace, VGGFace2, ArcFace + + silent (boolean): Suppress or allow some log messages for a quieter analysis process. + + anti_spoofing (boolean): Flag to enable anti spoofing (default is False). + + Returns: + List[List[Dict[str, Any]]]: + A list where each element corresponds to a source face and + contains a list of dictionaries with matching faces. + """ + embeddings_list = [] + valid_mask = [] + other_keys = set() + + for item in representations: + emb = item.get('embedding') + if emb is not None: + embeddings_list.append(emb) + valid_mask.append(True) + else: + embeddings_list.append(np.zeros_like(representations[0]['embedding'])) + valid_mask.append(False) + + other_keys.update(item.keys()) + + # remove embedding key from other keys + other_keys.discard('embedding') + other_keys = list(other_keys) + + embeddings = np.array(embeddings_list) # (N, D) + valid_mask = np.array(valid_mask) # (N,) + + data = { + key: np.array([item.get(key, None) for item in representations]) + for key in other_keys + } + + target_embeddings = [] + source_regions = [] + target_thresholds = [] + + for source_obj in source_objs: + if anti_spoofing and not source_obj.get("is_real", True): + raise ValueError("Spoof detected in the given image.") + + source_img = source_obj["face"] + source_region = source_obj["facial_area"] + + target_embedding_obj = representation.represent( + img_path=source_img, + model_name=model_name, + enforce_detection=enforce_detection, + detector_backend="skip", + align=align, + normalization=normalization, + ) + target_representation = target_embedding_obj[0]["embedding"] + + target_embeddings.append(target_representation) + source_regions.append(source_region) + + target_threshold = threshold or verification.find_threshold(model_name, distance_metric) + target_thresholds.append(target_threshold) + + target_embeddings = np.array(target_embeddings) # (M, D) + target_thresholds = np.array(target_thresholds) # (M,) + source_regions_arr = { + 'source_x': np.array([region['x'] for region in source_regions]), + 'source_y': np.array([region['y'] for region in source_regions]), + 'source_w': np.array([region['w'] for region in source_regions]), + 'source_h': np.array([region['h'] for region in source_regions]), + } + + def l2_normalize( + x: np.ndarray, axis: int = 1, epsilon: float = 1e-10 + ) -> np.ndarray: + """ + Normalize input vectors along a specified axis using L2 normalization + Args: + x (np.ndarray): input array + axis (int): axis along which to normalize + epsilon (float): small value to avoid division by zero + Returns: + np.ndarray: L2-normalized array of the same shape as input + """ + norm = np.linalg.norm(x, axis=axis, keepdims=True) + return x / (norm + epsilon) + + def find_cosine_distance_batch( + embeddings: np.ndarray, target_embeddings: np.ndarray + ) -> np.ndarray: + """ + Find the cosine distances between batches of embeddings + Args: + embeddings (np.ndarray): array of shape (N, D) + target_embeddings (np.ndarray): array of shape (M, D) + Returns: + np.ndarray: distance matrix of shape (M, N) + """ + embeddings_norm = l2_normalize(embeddings, axis=1) + target_embeddings_norm = l2_normalize(target_embeddings, axis=1) + cosine_similarities = np.dot(target_embeddings_norm, embeddings_norm.T) + cosine_distances = 1 - cosine_similarities + return cosine_distances + + def find_euclidean_distance_batch( + embeddings: np.ndarray, target_embeddings: np.ndarray + ) -> np.ndarray: + """ + Find the Euclidean distances between batches of embeddings + Args: + embeddings (np.ndarray): array of shape (N, D) + target_embeddings (np.ndarray): array of shape (M, D) + Returns: + np.ndarray: distance matrix of shape (M, N) + """ + diff = embeddings[None, :, :] - target_embeddings[:, None, :] # (M, N, D) + distances = np.linalg.norm(diff, axis=2) # (M, N) + return distances + + def find_distance_batch( + embeddings: np.ndarray, target_embeddings: np.ndarray, distance_metric: str, + ) -> np.ndarray: + """ + Find pairwise distances between batches of embeddings using the specified distance metric + Args: + embeddings (np.ndarray): array of shape (N, D) + target_embeddings (np.ndarray): array of shape (M, D) + distance_metric (str): distance metric ('cosine', 'euclidean', 'euclidean_l2') + Returns: + np.ndarray: distance matrix of shape (M, N) + """ + if distance_metric == "cosine": + distances = find_cosine_distance_batch(embeddings, target_embeddings) + elif distance_metric == "euclidean": + distances = find_euclidean_distance_batch(embeddings, target_embeddings) + elif distance_metric == "euclidean_l2": + embeddings_norm = l2_normalize(embeddings, axis=1) + target_embeddings_norm = l2_normalize(target_embeddings, axis=1) + distances = find_euclidean_distance_batch(embeddings_norm, target_embeddings_norm) + else: + raise ValueError("Invalid distance_metric passed - ", distance_metric) + return distances + + distances = find_distance_batch(embeddings, target_embeddings, distance_metric) # (M, N) + distances[:, ~valid_mask] = np.inf + + resp_obj = [] + + for i in range(len(target_embeddings)): + target_distances = distances[i] # (N,) + target_threshold = target_thresholds[i] + + N = embeddings.shape[0] + result_data = dict(data) + result_data.update({ + 'source_x': np.full(N, source_regions_arr['source_x'][i]), + 'source_y': np.full(N, source_regions_arr['source_y'][i]), + 'source_w': np.full(N, source_regions_arr['source_w'][i]), + 'source_h': np.full(N, source_regions_arr['source_h'][i]), + 'threshold': np.full(N, target_threshold), + 'distance': target_distances, + }) + + mask = target_distances <= target_threshold + filtered_data = {key: value[mask] for key, value in result_data.items()} + + sorted_indices = np.argsort(filtered_data['distance']) + sorted_data = {key: value[sorted_indices] for key, value in filtered_data.items()} + + num_results = len(sorted_data['distance']) + result_dicts = [ + {key: sorted_data[key][i] for key in sorted_data} + for i in range(num_results) + ] + resp_obj.append(result_dicts) + return resp_obj