diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py
index 99bcb64..4736474 100644
--- a/deepface/DeepFace.py
+++ b/deepface/DeepFace.py
@@ -206,9 +206,9 @@ def analyze(
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
 
     Returns:
-        (List[List[Dict[str, Any]]]): A list of analysis results if received batched image,
+        (List[List[Dict[str, Any]]]): A list of analysis results if a batched image is received,
             explained below.
-
+
         (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
             the analysis results for a detected face.
 
             Each dictionary in the list contains the following keys:
@@ -385,12 +385,12 @@ def represent(
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.
 
     Args:
-        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
+        img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
             The exact path to the image, a numpy array in BGR format,
             a file object that supports at least `.read` and is opened in binary
             mode, or a base64 encoded image. If the source image contains multiple faces,
@@ -423,8 +423,9 @@ def represent(
         max_faces (int): Set a limit on the number of faces to be processed (default is None).
 
     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries,
+            one per detected face. The result becomes a list of lists of dictionaries
+            if a batch input is passed. Each dictionary contains the following fields:
 
         - embedding (List[float]): Multidimensional vector representing facial features.
             The number of dimensions varies based on the reference model
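
Reviewer note, not part of the patch: a minimal sketch of the return-type contract that the `represent` signature change above introduces. The image paths are the ones used elsewhere in the test suite, and the model defaults apply.

from deepface import DeepFace

# Single input -> flat list with one dict per detected face.
single = DeepFace.represent(img_path="dataset/img1.jpg")
assert isinstance(single[0], dict)  # keys: embedding, facial_area, face_confidence

# Batched input (a sequence of images) -> one inner list per input image.
batched = DeepFace.represent(img_path=["dataset/img1.jpg", "dataset/couple.jpg"])
for faces in batched:      # one entry per input image
    for face in faces:     # one dict per face found in that image
        print(len(face["embedding"]))
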
diff --git a/deepface/models/Demography.py b/deepface/models/Demography.py
index 1493059..0d8a2de 100644
--- a/deepface/models/Demography.py
+++ b/deepface/models/Demography.py
@@ -24,7 +24,7 @@ class Demography(ABC):
     def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
         """
         Predict for single image or batched images.
-        This method uses legacy method while receiving single image as input.
-        And switch to batch prediction if receives batched images.
+        This method uses the legacy prediction path when it receives a single image
+        and switches to batch prediction when it receives batched images.
 
         Args:
@@ -35,11 +35,11 @@ class Demography(ABC):
             with x = image width, y = image height and c = channel
             The channel dimension will be 1 if input is grayscale. (For emotion model)
         """
-        if not self.model_name: # Check if called from derived class
+        if not self.model_name:  # Check if called from derived class
             raise NotImplementedError("no model selected")
         assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
 
-        if img_batch.shape[0] == 1: # Single image
+        if img_batch.shape[0] == 1:  # Single image
             # Predict with legacy method.
             return self.model(img_batch, training=False).numpy()[0, :]
 
@@ -48,10 +48,8 @@ class Demography(ABC):
             return self.model.predict_on_batch(img_batch)
 
     def _preprocess_batch_or_single_input(
-        self,
-        img: Union[np.ndarray, List[np.ndarray]]
+        self, img: Union[np.ndarray, List[np.ndarray]]
     ) -> np.ndarray:
-
         """
         Preprocess single or batch of images, return as 4-D numpy array.
 
         Args:
diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py
index c960159..f5a56c6 100644
--- a/deepface/models/demography/Age.py
+++ b/deepface/models/demography/Age.py
@@ -13,7 +13,6 @@ from deepface.commons.logger import Logger
 
 logger = Logger()
 
-# ----------------------------------------
 # dependency configurations
 
 tf_version = package_utils.get_tf_major_version()
@@ -25,12 +24,11 @@ else:
     from tensorflow.keras.models import Model, Sequential
     from tensorflow.keras.layers import Convolution2D, Flatten, Activation
 
-# ----------------------------------------
-
 WEIGHTS_URL = (
     "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
 )
 
+
 # pylint: disable=too-few-public-methods
 class ApparentAgeClient(Demography):
     """
@@ -49,7 +47,7 @@ class ApparentAgeClient(Demography):
             List of images as List[np.ndarray] or
             Batch of images as np.ndarray (n, 224, 224, 3)
         Returns:
-            np.ndarray (age_classes,) if single image,
+            np.ndarray (age_classes,) if single image,
             np.ndarray (n, age_classes) if batched images.
         """
         # Preprocessing input image or image list.
@@ -59,11 +57,10 @@ class ApparentAgeClient(Demography):
         age_predictions = self._predict_internal(imgs)
 
         # Calculate apparent ages
-        if len(age_predictions.shape) == 1: # Single prediction list
+        if len(age_predictions.shape) == 1:  # Single prediction list
             return find_apparent_age(age_predictions)
 
-        return np.array([
-            find_apparent_age(age_prediction) for age_prediction in age_predictions])
+        return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions])
 
 
 def load_model(
@@ -100,6 +97,7 @@ def load_model(
 
     return age_model
 
+
 def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
     """
     Find apparent age prediction from a given probas of ages
@@ -108,7 +106,9 @@ def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
     Args:
         age_predictions (age_classes,)
     Returns:
         apparent_age (float)
     """
-    assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \
+    assert (
+        len(age_predictions.shape) == 1
+    ), f"Input should be a list of predictions, \
         not batched. Got shape: {age_predictions.shape}"
     output_indexes = np.arange(0, 101)
     apparent_age = np.sum(age_predictions * output_indexes)
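
Reviewer note, not part of the patch: a small worked example of the expected-value computation that `find_apparent_age` performs, using an illustrative toy distribution.

import numpy as np

# Toy softmax output over the 101 age classes (0..100): mass around age 30.
age_probas = np.zeros(101)
age_probas[29:32] = [0.25, 0.50, 0.25]

# Apparent age is the expectation of the class index under the distribution:
# 29 * 0.25 + 30 * 0.50 + 31 * 0.25 = 30.0
apparent_age = np.sum(age_probas * np.arange(0, 101))
print(apparent_age)  # 30.0
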
diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py
index 85cbe81..d3ce8e6 100644
--- a/deepface/modules/demography.py
+++ b/deepface/modules/demography.py
@@ -123,7 +123,6 @@ def analyze(
             batch_resp_obj.append(resp_obj)
         return batch_resp_obj
 
-
     # if actions is passed as tuple with single item, interestingly it becomes str here
     if isinstance(actions, str):
         actions = (actions,)
diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py
index 90e8c29..4002b47 100644
--- a/deepface/modules/recognition.py
+++ b/deepface/modules/recognition.py
@@ -398,6 +398,7 @@ def __find_bulk_embeddings(
                 enforce_detection=enforce_detection,
                 align=align,
                 expand_percentage=expand_percentage,
+                color_face="bgr",  # `represent` expects images in BGR format.
             )
 
         except ValueError as err:
diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py
index 56eaef2..2be4fec 100644
--- a/deepface/modules/representation.py
+++ b/deepface/modules/representation.py
@@ -20,12 +20,12 @@ def represent(
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.
 
     Args:
-        img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
+        img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
             The exact path to the image, a numpy array in BGR format,
             a base64 encoded image, or a sequence of these.
             If the source image contains multiple faces,
@@ -53,8 +53,9 @@ def represent(
         max_faces (int): Set a limit on the number of faces to be processed (default is None).
 
     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries,
+            one per detected face. The result becomes a list of lists of dictionaries
+            if a batch input is passed. Each dictionary contains the following fields:
 
         - embedding (List[float]): Multidimensional vector representing facial features.
             The number of dimensions varies based on the reference model
@@ -80,16 +81,13 @@ def represent(
     else:
         images = [img_path]
 
-    batch_images = []
-    batch_regions = []
-    batch_confidences = []
+    batch_images, batch_regions, batch_confidences, batch_indexes = [], [], [], []
 
-    for single_img_path in images:
-        # ---------------------------------
-        # we have run pre-process in verification.
-        # so, this can be skipped if it is coming from verify.
+    for idx, single_img_path in enumerate(images):
+        # we have already run pre-processing in verification, so skip it if coming from verify.
        target_size = model.input_shape
        if detector_backend != "skip":
+            # Images are returned in RGB format.
             img_objs = detection.extract_faces(
                 img_path=single_img_path,
                 detector_backend=detector_backend,
@@ -107,6 +105,9 @@ def represent(
             if len(img.shape) != 3:
                 raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
 
+            # Convert to RGB format to keep compatibility with `extract_faces`.
+            img = img[:, :, ::-1]
+
             # make dummy region and confidence to keep compatibility with `extract_faces`
             img_objs = [
                 {
@@ -130,9 +131,10 @@ def represent(
         for img_obj in img_objs:
             if anti_spoofing is True and img_obj.get("is_real", True) is False:
                 raise ValueError("Spoof detected in the given image.")
+
             img = img_obj["face"]
 
-            # bgr to rgb
+            # rgb to bgr
             img = img[:, :, ::-1]
 
             region = img_obj["facial_area"]
@@ -151,22 +153,25 @@ def represent(
             batch_images.append(img)
             batch_regions.append(region)
             batch_confidences.append(confidence)
+            batch_indexes.append(idx)
 
     # Convert list of images to a numpy array for batch processing
     batch_images = np.concatenate(batch_images, axis=0)
 
     # Forward pass through the model for the entire batch
     embeddings = model.forward(batch_images)
 
-    if len(batch_images) == 1:
-        embeddings = [embeddings]
-
-    for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
-        resp_objs.append(
-            {
-                "embedding": embedding,
-                "facial_area": region,
-                "face_confidence": confidence,
-            }
-        )
+    for idx in range(len(images)):
+        resp_obj = []
+        for idy, batch_index in enumerate(batch_indexes):
+            if idx == batch_index:
+                resp_obj.append(
+                    {
+                        "embedding": embeddings if len(batch_images) == 1 else embeddings[idy],
+                        "facial_area": batch_regions[idy],
+                        "face_confidence": batch_confidences[idy],
+                    }
+                )
+        resp_objs.append(resp_obj)
 
-    return resp_objs
+    return resp_objs[0] if len(images) == 1 else resp_objs
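
Reviewer note, not part of the patch: the two `img[:, :, ::-1]` flips above form a lossless round-trip. `extract_faces` returns faces in RGB and the recognition models expect BGR, so the `skip` branch first mirrors the channel axis to match `extract_faces`, and the common path mirrors it back before the forward pass. A minimal sketch:

import numpy as np

bgr = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

rgb = bgr[:, :, ::-1]   # what the "skip" branch does on its BGR input
back = rgb[:, :, ::-1]  # what the common path does before model.forward

assert np.array_equal(bgr, back)  # reversing the channel axis twice is lossless
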
"dataset/img4.jpg" + +def test_analyze_for_numpy_batched_image(): + img1_path = "dataset/img4.jpg" + img2_path = "dataset/couple.jpg" + # Copy and combine the same image to create multiple faces - img = cv2.imread(img) - img = np.stack([img, img]) - assert len(img.shape) == 4 # Check dimension. - assert img.shape[0] == 2 # Check batch size. + img1 = cv2.imread(img1_path) + img2 = cv2.imread(img2_path) + + expected_num_faces = [1, 2] + + img1 = cv2.resize(img1, (500, 500)) + img2 = cv2.resize(img2, (500, 500)) + + img = np.stack([img1, img2]) + assert len(img.shape) == 4 # Check dimension. + assert img.shape[0] == 2 # Check batch size. demography_batch = DeepFace.analyze(img, silent=True) # 2 image in batch, so 2 demography objects. - assert len(demography_batch) == 2 + assert len(demography_batch) == 2 - for demography_objs in demography_batch: - assert len(demography_objs) == 1 # 1 face in each image - for demography in demography_objs: # Iterate over faces - assert type(demography) == dict # Check type + for i, demography_objs in enumerate(demography_batch): + + assert len(demography_objs) == expected_num_faces[i] + for demography in demography_objs: # Iterate over faces + assert isinstance(demography, dict) # Check type assert demography["age"] > 20 and demography["age"] < 40 - assert demography["dominant_gender"] == "Woman" + assert demography["dominant_gender"] in ["Woman", "Man"] + logger.info("✅ test analyze for multiple faces done") + def test_batch_detect_age_for_multiple_faces(): # Load test image and resize to model input size img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) @@ -176,6 +189,7 @@ def test_batch_detect_age_for_multiple_faces(): assert np.array_equal(int(results[0]), int(results[1])) logger.info("✅ test batch detect age for multiple faces done") + def test_batch_detect_emotion_for_multiple_faces(): # Load test image and resize to model input size img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) @@ -187,6 +201,7 @@ def test_batch_detect_emotion_for_multiple_faces(): assert np.array_equal(results[0], results[1]) logger.info("✅ test batch detect emotion for multiple faces done") + def test_batch_detect_gender_for_multiple_faces(): # Load test image and resize to model input size img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) @@ -198,6 +213,7 @@ def test_batch_detect_gender_for_multiple_faces(): assert np.array_equal(results[0], results[1]) logger.info("✅ test batch detect gender for multiple faces done") + def test_batch_detect_race_for_multiple_faces(): # Load test image and resize to model input size img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) @@ -207,4 +223,4 @@ def test_batch_detect_race_for_multiple_faces(): assert len(results) == 2 # Check two races are the same assert np.array_equal(results[0], results[1]) - logger.info("✅ test batch detect race for multiple faces done") \ No newline at end of file + logger.info("✅ test batch detect race for multiple faces done") diff --git a/tests/test_represent.py b/tests/test_represent.py index bc83a4e..e5a7eab 100644 --- a/tests/test_represent.py +++ b/tests/test_represent.py @@ -15,7 +15,12 @@ logger = Logger() def test_standard_represent(): img_path = "dataset/img1.jpg" embedding_objs = DeepFace.represent(img_path) + # type should be list of dict + assert isinstance(embedding_objs, list) + for embedding_obj in embedding_objs: + assert isinstance(embedding_obj, dict) + embedding = embedding_obj["embedding"] logger.debug(f"Function returned {len(embedding)} dimensional 
vector") assert len(embedding) == 4096 @@ -25,18 +30,18 @@ def test_standard_represent(): def test_standard_represent_with_io_object(): img_path = "dataset/img1.jpg" default_embedding_objs = DeepFace.represent(img_path) - io_embedding_objs = DeepFace.represent(open(img_path, 'rb')) + io_embedding_objs = DeepFace.represent(open(img_path, "rb")) assert default_embedding_objs == io_embedding_objs # Confirm non-seekable io objects are handled properly - io_obj = io.BytesIO(open(img_path, 'rb').read()) + io_obj = io.BytesIO(open(img_path, "rb").read()) io_obj.seek = None no_seek_io_embedding_objs = DeepFace.represent(io_obj) assert default_embedding_objs == no_seek_io_embedding_objs # Confirm non-image io objects raise exceptions - with pytest.raises(ValueError, match='Failed to decode image'): - DeepFace.represent(io.BytesIO(open(r'../requirements.txt', 'rb').read())) + with pytest.raises(ValueError, match="Failed to decode image"): + DeepFace.represent(io.BytesIO(open(r"../requirements.txt", "rb").read())) logger.info("✅ test standard represent with io object function done") @@ -57,6 +62,27 @@ def test_represent_for_skipped_detector_backend_with_image_path(): logger.info("✅ test represent function for skipped detector and image path input backend done") +def test_represent_for_preloaded_image(): + face_img = "dataset/img5.jpg" + img = cv2.imread(face_img) + img_objs = DeepFace.represent(img_path=img) + # type should be list of dict + assert isinstance(img_objs, list) + assert len(img_objs) >= 1 + + for img_obj in img_objs: + assert isinstance(img_obj, dict) + assert "embedding" in img_obj.keys() + assert "facial_area" in img_obj.keys() + assert isinstance(img_obj["facial_area"], dict) + assert "x" in img_obj["facial_area"].keys() + assert "y" in img_obj["facial_area"].keys() + assert "w" in img_obj["facial_area"].keys() + assert "h" in img_obj["facial_area"].keys() + assert "face_confidence" in img_obj.keys() + logger.info("✅ test represent function for skipped detector and preloaded image done") + + def test_represent_for_skipped_detector_backend_with_preloaded_image(): face_img = "dataset/img5.jpg" img = cv2.imread(face_img) @@ -85,40 +111,127 @@ def test_max_faces(): assert len(results) == max_faces -@pytest.mark.parametrize("model_name", [ - "VGG-Face", - "Facenet", - "SFace", -]) -def test_batched_represent(model_name): +def test_represent_detector_backend(): + # Results using a detection backend. + results_1 = DeepFace.represent(img_path="dataset/img1.jpg") + assert len(results_1) == 1 + + # Results performing face extraction first. + faces = DeepFace.extract_faces(img_path="dataset/img1.jpg", color_face='bgr') + assert len(faces) == 1 + + # Images sent into represent need to be in BGR format. + img = faces[0]['face'] + results_2 = DeepFace.represent(img_path=img, detector_backend="skip") + assert len(results_2) == 1 + + # The embeddings should be the exact same for both cases. 
@@ -85,40 +111,127 @@ def test_max_faces():
     assert len(results) == max_faces
 
 
-@pytest.mark.parametrize("model_name", [
-    "VGG-Face",
-    "Facenet",
-    "SFace",
-])
-def test_batched_represent(model_name):
+def test_represent_detector_backend():
+    # Results using a detection backend.
+    results_1 = DeepFace.represent(img_path="dataset/img1.jpg")
+    assert len(results_1) == 1
+
+    # Results performing face extraction first.
+    faces = DeepFace.extract_faces(img_path="dataset/img1.jpg", color_face="bgr")
+    assert len(faces) == 1
+
+    # Images sent into represent need to be in BGR format.
+    img = faces[0]["face"]
+    results_2 = DeepFace.represent(img_path=img, detector_backend="skip")
+    assert len(results_2) == 1
+
+    # The embeddings should be exactly the same in both cases.
+    embedding_1 = results_1[0]["embedding"]
+    embedding_2 = results_2[0]["embedding"]
+    assert embedding_1 == embedding_2
+    logger.info("✅ test represent function for consistent output done")
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_list_input(model_name):
     img_paths = [
         "dataset/img1.jpg",
         "dataset/img2.jpg",
         "dataset/img3.jpg",
         "dataset/img4.jpg",
         "dataset/img5.jpg",
+        "dataset/couple.jpg",
     ]
+    expected_faces = [1, 1, 1, 1, 1, 2]
 
-    embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
-    assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}"
+    batched_embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
+
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+    assert len(batched_embedding_objs) == len(
+        img_paths
+    ), f"Expected {len(img_paths)} embeddings, got {len(batched_embedding_objs)}"
 
-    if model_name == "VGG-Face":
+    # each inner element should be a list of dict; the last image has two faces
+    for idx, embedding_objs in enumerate(batched_embedding_objs):
+        assert isinstance(embedding_objs, list)
         for embedding_obj in embedding_objs:
-            embedding = embedding_obj["embedding"]
-            logger.debug(f"Function returned {len(embedding)} dimensional vector")
-            assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}"
+            assert isinstance(embedding_obj, dict)
 
-    embedding_objs_one_by_one = [
-        embedding_obj
-        for img_path in img_paths
-        for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
-    ]
+        assert expected_faces[idx] == len(
+            embedding_objs
+        ), f"{img_paths[idx]} has {expected_faces[idx]} faces, but got {len(embedding_objs)} embeddings!"
 
-    for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
-        assert np.allclose(
-            embedding_obj_one_by_one["embedding"],
-            embedding_obj["embedding"],
-            rtol=1e-2,
-            atol=1e-2
-        ), "Embeddings do not match within tolerance"
+    for idx, img_path in enumerate(img_paths):
+        single_embedding_objs = DeepFace.represent(img_path=img_path, model_name=model_name)
+        # type should be list of dict for single input
+        assert isinstance(single_embedding_objs, list)
+        for embedding_obj in single_embedding_objs:
+            assert isinstance(embedding_obj, dict)
+
+        assert len(single_embedding_objs) == len(batched_embedding_objs[idx])
+
+        for alpha, beta in zip(single_embedding_objs, batched_embedding_objs[idx]):
+            assert np.allclose(
+                alpha["embedding"], beta["embedding"], rtol=1e-2, atol=1e-2
+            ), "Embeddings do not match within tolerance"
 
-    logger.info(f"✅ test batch represent function for model {model_name} done")
+    logger.info(f"✅ test batch represent function with string input for model {model_name} done")
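
Reviewer note, not part of the patch: batched and one-by-one inference can take different numeric paths (`predict_on_batch` versus a direct model call), so embeddings may drift slightly; that is why the comparison above uses `np.allclose` with tolerances rather than strict equality. A minimal sketch of the check:

import numpy as np

batched = np.array([0.1234, 0.5678])
single = np.array([0.1233, 0.5679])  # tiny drift between code paths

# rtol/atol mirror the tolerances used in the test above
assert np.allclose(single, batched, rtol=1e-2, atol=1e-2)
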
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_numpy_input(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+        "dataset/couple.jpg",
+    ]
+    expected_faces = [1, 1, 1, 1, 1, 2]
+
+    imgs = []
+    for img_path in img_paths:
+        img = cv2.imread(img_path)
+        img = cv2.resize(img, (1000, 1000))
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        imgs.append(img)
+
+    imgs = np.array(imgs)
+    assert imgs.ndim == 4 and imgs.shape[0] == len(img_paths)
+
+    batched_embedding_objs = DeepFace.represent(img_path=imgs, model_name=model_name)
+
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+    for idx, batched_embedding_obj in enumerate(batched_embedding_objs):
+        assert isinstance(batched_embedding_obj, list)
+        # it also has to have the expected number of faces
+        assert len(batched_embedding_obj) == expected_faces[idx]
+        for embedding_obj in batched_embedding_obj:
+            assert isinstance(embedding_obj, dict)
+
+    # we should have the same number of embeddings as the number of images
+    assert len(batched_embedding_objs) == len(img_paths)
+
+    logger.info(f"✅ test batch represent function with numpy input for model {model_name} done")
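
Reviewer note, not part of the patch: a closing sketch of the input-shape constraint the numpy test works around. A numpy batch must be a single 4-D tensor, so every image needs the same spatial size before stacking; list input has no such restriction because each element is preprocessed independently.

import cv2
import numpy as np
from deepface import DeepFace

a = cv2.imread("dataset/img1.jpg")
b = cv2.imread("dataset/couple.jpg")

# List input: mixed sizes are fine, each image is handled on its own.
nested = DeepFace.represent(img_path=[a, b])

# Numpy input: shapes must match before stacking into one 4-D array.
batch = np.stack([cv2.resize(a, (1000, 1000)), cv2.resize(b, (1000, 1000))])
nested_np = DeepFace.represent(img_path=batch)
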