Merge branch 'master' of https://github.com/Raghucharan16/deepface

Commit 41ae9bbcf3
deepface/DeepFace.py

@@ -206,9 +206,9 @@ def analyze(
         anti_spoofing (boolean): Flag to enable anti spoofing (default is False).

     Returns:
         (List[List[Dict[str, Any]]]): A list of analysis results if received batched image,
             explained below.

         (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
             the analysis results for a detected face. Each dictionary in the list contains the
             following keys:
@@ -385,12 +385,12 @@ def represent(
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.

     Args:
         img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
             The exact path to the image, a numpy array
             in BGR format, a file object that supports at least `.read` and is opened in binary
             mode, or a base64 encoded image. If the source image contains multiple faces,
@@ -423,8 +423,9 @@ def represent(
         max_faces (int): Set a limit on the number of faces to be processed (default is None).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
+            Result type becomes List of List of Dict if batch input passed.
+            Each contains the following fields:

         - embedding (List[float]): Multidimensional vector representing facial features.
             The number of dimensions varies based on the reference model
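A minimal usage sketch of the widened return type (not part of this commit; the image paths are hypothetical): a single input still yields a List[Dict], while a sequence input yields one List[Dict] per image.

    from deepface import DeepFace

    # Single image: List[Dict[str, Any]], one dict per detected face.
    single = DeepFace.represent(img_path="img1.jpg")

    # Batch input: List[List[Dict[str, Any]]], one inner list per image.
    batch = DeepFace.represent(img_path=["img1.jpg", "img2.jpg"])

    def per_image_results(results):
        # Normalize both shapes to a list of per-image face lists.
        if results and isinstance(results[0], list):
            return results
        return [results]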
deepface/models/Demography.py

@@ -24,7 +24,7 @@ class Demography(ABC):
     def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
         """
         Predict for single image or batched images.
         This method uses the legacy prediction path for a single image
         and switches to batch prediction for batched images.

         Args:
@@ -35,11 +35,11 @@ class Demography(ABC):
                 with x = image width, y = image height and c = channel
                 The channel dimension will be 1 if input is grayscale. (For emotion model)
         """
         if not self.model_name:  # Check if called from derived class
             raise NotImplementedError("no model selected")
         assert img_batch.ndim == 4, "expected 4-dimensional tensor input"

         if img_batch.shape[0] == 1:  # Single image
             # Predict with legacy method.
             return self.model(img_batch, training=False).numpy()[0, :]

@@ -48,10 +48,8 @@ class Demography(ABC):
         return self.model.predict_on_batch(img_batch)

     def _preprocess_batch_or_single_input(
-        self,
-        img: Union[np.ndarray, List[np.ndarray]]
+        self, img: Union[np.ndarray, List[np.ndarray]]
     ) -> np.ndarray:
-
         """
         Preprocess single or batch of images, return as 4-D numpy array.
         Args:
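For reference, a standalone sketch of the single-versus-batch dispatch that the two hunks above implement (simplified; `model` stands in for the wrapped Keras model):

    import numpy as np

    def predict(model, img_batch: np.ndarray) -> np.ndarray:
        assert img_batch.ndim == 4, "expected (n, x, y, c) input"
        if img_batch.shape[0] == 1:
            # Legacy single-image path: call the model and drop the batch axis.
            return model(img_batch, training=False).numpy()[0, :]
        # Batched path: one predict_on_batch call for the whole batch.
        return model.predict_on_batch(img_batch)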
deepface/models/demography/Age.py

@@ -13,7 +13,6 @@ from deepface.commons.logger import Logger

 logger = Logger()

-# ----------------------------------------
 # dependency configurations

 tf_version = package_utils.get_tf_major_version()
@@ -25,12 +24,11 @@ else:
     from tensorflow.keras.models import Model, Sequential
     from tensorflow.keras.layers import Convolution2D, Flatten, Activation

-# ----------------------------------------

 WEIGHTS_URL = (
     "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
 )


 # pylint: disable=too-few-public-methods
 class ApparentAgeClient(Demography):
     """
@@ -49,7 +47,7 @@ class ApparentAgeClient(Demography):
                 List of images as List[np.ndarray] or
                 Batch of images as np.ndarray (n, 224, 224, 3)
         Returns:
             np.ndarray (age_classes,) if single image,
             np.ndarray (n, age_classes) if batched images.
         """
         # Preprocessing input image or image list.
@@ -59,11 +57,10 @@ class ApparentAgeClient(Demography):
         age_predictions = self._predict_internal(imgs)

         # Calculate apparent ages
         if len(age_predictions.shape) == 1:  # Single prediction list
             return find_apparent_age(age_predictions)

-        return np.array([
-            find_apparent_age(age_prediction) for age_prediction in age_predictions])
+        return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions])


 def load_model(
@@ -100,6 +97,7 @@ def load_model(

     return age_model


 def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
     """
     Find apparent age prediction from a given probas of ages
@@ -108,7 +106,9 @@ def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
     Returns:
         apparent_age (float)
     """
-    assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \
+    assert (
+        len(age_predictions.shape) == 1
+    ), f"Input should be a list of predictions, \
         not batched. Got shape: {age_predictions.shape}"
     output_indexes = np.arange(0, 101)
     apparent_age = np.sum(age_predictions * output_indexes)
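To make the computation above concrete, a small worked example (the toy distribution is made up): apparent age is the expected value over the 101 age classes, not the argmax.

    import numpy as np

    age_probas = np.zeros(101)
    age_probas[[29, 30, 31]] = [0.2, 0.5, 0.3]  # toy probability distribution

    # Expected value: 29*0.2 + 30*0.5 + 31*0.3 = 30.1
    apparent_age = np.sum(age_probas * np.arange(0, 101))
    print(apparent_age)  # 30.1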
deepface/modules/demography.py

@@ -123,7 +123,6 @@ def analyze(
             batch_resp_obj.append(resp_obj)
         return batch_resp_obj

-
     # if actions is passed as tuple with single item, interestingly it becomes str here
     if isinstance(actions, str):
         actions = (actions,)
deepface/modules/recognition.py

@@ -398,6 +398,7 @@ def __find_bulk_embeddings(
                 enforce_detection=enforce_detection,
                 align=align,
                 expand_percentage=expand_percentage,
+                color_face='bgr'  # `represent` expects images in bgr format.
             )

         except ValueError as err:
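A hedged usage sketch of the convention this line relies on (the image path is hypothetical): crops returned by extract_faces with color_face="bgr" can be passed straight to represent with the detector skipped.

    from deepface import DeepFace

    faces = DeepFace.extract_faces(img_path="img1.jpg", color_face="bgr")
    embedding = DeepFace.represent(img_path=faces[0]["face"], detector_backend="skip")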
deepface/modules/representation.py

@@ -20,12 +20,12 @@ def represent(
     normalization: str = "base",
     anti_spoofing: bool = False,
     max_faces: Optional[int] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
     """
     Represent facial images as multi-dimensional vector embeddings.

     Args:
         img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
             The exact path to the image, a numpy array in BGR format,
             a base64 encoded image, or a sequence of these.
             If the source image contains multiple faces,
@@ -53,8 +53,9 @@ def represent(
         max_faces (int): Set a limit on the number of faces to be processed (default is None).

     Returns:
-        results (List[Dict[str, Any]]): A list of dictionaries, each containing the
-            following fields:
+        results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
+            Result type becomes List of List of Dict if batch input passed.
+            Each contains the following fields:

         - embedding (List[float]): Multidimensional vector representing facial features.
             The number of dimensions varies based on the reference model
@@ -80,16 +81,13 @@ def represent(
     else:
         images = [img_path]

-    batch_images = []
-    batch_regions = []
-    batch_confidences = []
+    batch_images, batch_regions, batch_confidences, batch_indexes = [], [], [], []

-    for single_img_path in images:
-        # ---------------------------------
-        # we have run pre-process in verification.
-        # so, this can be skipped if it is coming from verify.
+    for idx, single_img_path in enumerate(images):
+        # we have run pre-process in verification. so, skip if it is coming from verify.
         target_size = model.input_shape
         if detector_backend != "skip":
+            # Images are returned in RGB format.
             img_objs = detection.extract_faces(
                 img_path=single_img_path,
                 detector_backend=detector_backend,
@@ -107,6 +105,9 @@ def represent(
         if len(img.shape) != 3:
             raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")

+        # Convert to RGB format to keep compatibility with `extract_faces`.
+        img = img[:, :, ::-1]
+
         # make dummy region and confidence to keep compatibility with `extract_faces`
         img_objs = [
             {
|
|||||||
for img_obj in img_objs:
|
for img_obj in img_objs:
|
||||||
if anti_spoofing is True and img_obj.get("is_real", True) is False:
|
if anti_spoofing is True and img_obj.get("is_real", True) is False:
|
||||||
raise ValueError("Spoof detected in the given image.")
|
raise ValueError("Spoof detected in the given image.")
|
||||||
|
|
||||||
img = img_obj["face"]
|
img = img_obj["face"]
|
||||||
|
|
||||||
# bgr to rgb
|
# rgb to bgr
|
||||||
img = img[:, :, ::-1]
|
img = img[:, :, ::-1]
|
||||||
|
|
||||||
region = img_obj["facial_area"]
|
region = img_obj["facial_area"]
|
||||||
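Taken together, the flips in the last two hunks keep one channel convention: extract_faces returns RGB crops, the skip path now flips raw BGR input to RGB to match, and the single rgb-to-bgr flip before normalization is then correct for both sources.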
@@ -151,22 +153,25 @@ def represent(
             batch_images.append(img)
             batch_regions.append(region)
             batch_confidences.append(confidence)
+            batch_indexes.append(idx)

     # Convert list of images to a numpy array for batch processing
     batch_images = np.concatenate(batch_images, axis=0)

     # Forward pass through the model for the entire batch
     embeddings = model.forward(batch_images)
-    if len(batch_images) == 1:
-        embeddings = [embeddings]

-    for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
-        resp_objs.append(
-            {
-                "embedding": embedding,
-                "facial_area": region,
-                "face_confidence": confidence,
-            }
-        )
+    for idx in range(0, len(images)):
+        resp_obj = []
+        for idy, batch_index in enumerate(batch_indexes):
+            if idx == batch_index:
+                resp_obj.append(
+                    {
+                        "embedding": embeddings if len(batch_images) == 1 else embeddings[idy],
+                        "facial_area": batch_regions[idy],
+                        "face_confidence": batch_confidences[idy],
+                    }
+                )
+        resp_objs.append(resp_obj)

-    return resp_objs
+    return resp_objs[0] if len(images) == 1 else resp_objs
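A standalone sketch (not the library code) of what the regrouping loop above does: batch_indexes records which input image each detected face came from, so the flat list of per-face embeddings can be folded back into one list per image.

    def group_by_image(n_images, batch_indexes, embeddings):
        # One bucket per input image.
        grouped = [[] for _ in range(n_images)]
        for face_pos, img_idx in enumerate(batch_indexes):
            grouped[img_idx].append(embeddings[face_pos])
        return grouped

    # Two faces detected in image 0, one face in image 1:
    print(group_by_image(2, [0, 0, 1], ["e0", "e1", "e2"]))
    # [['e0', 'e1'], ['e2']]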
tests/test_analyze.py

@@ -144,26 +144,39 @@ def test_analyze_for_different_detectors():
     else:
         assert result["gender"]["Man"] < result["gender"]["Woman"]

-def test_analyze_for_batched_image():
-    img = "dataset/img4.jpg"
+
+def test_analyze_for_numpy_batched_image():
+    img1_path = "dataset/img4.jpg"
+    img2_path = "dataset/couple.jpg"

     # Copy and combine the same image to create multiple faces
-    img = cv2.imread(img)
-    img = np.stack([img, img])
-    assert len(img.shape) == 4  # Check dimension.
-    assert img.shape[0] == 2  # Check batch size.
+    img1 = cv2.imread(img1_path)
+    img2 = cv2.imread(img2_path)
+
+    expected_num_faces = [1, 2]
+
+    img1 = cv2.resize(img1, (500, 500))
+    img2 = cv2.resize(img2, (500, 500))
+
+    img = np.stack([img1, img2])
+    assert len(img.shape) == 4  # Check dimension.
+    assert img.shape[0] == 2  # Check batch size.

     demography_batch = DeepFace.analyze(img, silent=True)
     # 2 images in batch, so 2 demography objects.
     assert len(demography_batch) == 2

-    for demography_objs in demography_batch:
-        assert len(demography_objs) == 1  # 1 face in each image
-        for demography in demography_objs:  # Iterate over faces
-            assert type(demography) == dict  # Check type
+    for i, demography_objs in enumerate(demography_batch):
+        assert len(demography_objs) == expected_num_faces[i]
+        for demography in demography_objs:  # Iterate over faces
+            assert isinstance(demography, dict)  # Check type
             assert demography["age"] > 20 and demography["age"] < 40
-            assert demography["dominant_gender"] == "Woman"
+            assert demography["dominant_gender"] in ["Woman", "Man"]

     logger.info("✅ test analyze for multiple faces done")


 def test_batch_detect_age_for_multiple_faces():
     # Load test image and resize to model input size
     img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
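For context, the batched path this test covers can be exercised directly (a sketch using the repo's test images; analyze returns one list of face dicts per image in the batch):

    import cv2
    import numpy as np
    from deepface import DeepFace

    img1 = cv2.resize(cv2.imread("dataset/img4.jpg"), (500, 500))
    img2 = cv2.resize(cv2.imread("dataset/couple.jpg"), (500, 500))
    batch = np.stack([img1, img2])  # shape (2, 500, 500, 3)

    results = DeepFace.analyze(batch, silent=True)
    assert len(results) == 2  # one result list per image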
@@ -176,6 +189,7 @@ def test_batch_detect_age_for_multiple_faces():
     assert np.array_equal(int(results[0]), int(results[1]))
     logger.info("✅ test batch detect age for multiple faces done")


 def test_batch_detect_emotion_for_multiple_faces():
     # Load test image and resize to model input size
     img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
@@ -187,6 +201,7 @@ def test_batch_detect_emotion_for_multiple_faces():
     assert np.array_equal(results[0], results[1])
     logger.info("✅ test batch detect emotion for multiple faces done")


 def test_batch_detect_gender_for_multiple_faces():
     # Load test image and resize to model input size
     img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
@@ -198,6 +213,7 @@ def test_batch_detect_gender_for_multiple_faces():
     assert np.array_equal(results[0], results[1])
     logger.info("✅ test batch detect gender for multiple faces done")


 def test_batch_detect_race_for_multiple_faces():
     # Load test image and resize to model input size
     img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
@@ -207,4 +223,4 @@ def test_batch_detect_race_for_multiple_faces():
     assert len(results) == 2
     # Check two races are the same
     assert np.array_equal(results[0], results[1])
     logger.info("✅ test batch detect race for multiple faces done")
tests/test_represent.py

@@ -15,7 +15,12 @@ logger = Logger()
 def test_standard_represent():
     img_path = "dataset/img1.jpg"
     embedding_objs = DeepFace.represent(img_path)
+    # type should be list of dict
+    assert isinstance(embedding_objs, list)
+
     for embedding_obj in embedding_objs:
+        assert isinstance(embedding_obj, dict)
+
         embedding = embedding_obj["embedding"]
         logger.debug(f"Function returned {len(embedding)} dimensional vector")
         assert len(embedding) == 4096
@@ -25,18 +30,18 @@ def test_standard_represent():
 def test_standard_represent_with_io_object():
     img_path = "dataset/img1.jpg"
     default_embedding_objs = DeepFace.represent(img_path)
-    io_embedding_objs = DeepFace.represent(open(img_path, 'rb'))
+    io_embedding_objs = DeepFace.represent(open(img_path, "rb"))
     assert default_embedding_objs == io_embedding_objs

     # Confirm non-seekable io objects are handled properly
-    io_obj = io.BytesIO(open(img_path, 'rb').read())
+    io_obj = io.BytesIO(open(img_path, "rb").read())
     io_obj.seek = None
     no_seek_io_embedding_objs = DeepFace.represent(io_obj)
     assert default_embedding_objs == no_seek_io_embedding_objs

     # Confirm non-image io objects raise exceptions
-    with pytest.raises(ValueError, match='Failed to decode image'):
-        DeepFace.represent(io.BytesIO(open(r'../requirements.txt', 'rb').read()))
+    with pytest.raises(ValueError, match="Failed to decode image"):
+        DeepFace.represent(io.BytesIO(open(r"../requirements.txt", "rb").read()))

     logger.info("✅ test standard represent with io object function done")
@@ -57,6 +62,27 @@ def test_represent_for_skipped_detector_backend_with_image_path():
     logger.info("✅ test represent function for skipped detector and image path input backend done")


+def test_represent_for_preloaded_image():
+    face_img = "dataset/img5.jpg"
+    img = cv2.imread(face_img)
+    img_objs = DeepFace.represent(img_path=img)
+    # type should be list of dict
+    assert isinstance(img_objs, list)
+    assert len(img_objs) >= 1
+
+    for img_obj in img_objs:
+        assert isinstance(img_obj, dict)
+        assert "embedding" in img_obj.keys()
+        assert "facial_area" in img_obj.keys()
+        assert isinstance(img_obj["facial_area"], dict)
+        assert "x" in img_obj["facial_area"].keys()
+        assert "y" in img_obj["facial_area"].keys()
+        assert "w" in img_obj["facial_area"].keys()
+        assert "h" in img_obj["facial_area"].keys()
+        assert "face_confidence" in img_obj.keys()
+    logger.info("✅ test represent function for preloaded image done")
+
+
 def test_represent_for_skipped_detector_backend_with_preloaded_image():
     face_img = "dataset/img5.jpg"
     img = cv2.imread(face_img)
@@ -85,40 +111,127 @@ def test_max_faces():
     assert len(results) == max_faces


-@pytest.mark.parametrize("model_name", [
-    "VGG-Face",
-    "Facenet",
-    "SFace",
-])
-def test_batched_represent(model_name):
+def test_represent_detector_backend():
+    # Results using a detection backend.
+    results_1 = DeepFace.represent(img_path="dataset/img1.jpg")
+    assert len(results_1) == 1
+
+    # Results performing face extraction first.
+    faces = DeepFace.extract_faces(img_path="dataset/img1.jpg", color_face='bgr')
+    assert len(faces) == 1
+
+    # Images sent into represent need to be in BGR format.
+    img = faces[0]['face']
+    results_2 = DeepFace.represent(img_path=img, detector_backend="skip")
+    assert len(results_2) == 1
+
+    # The embeddings should be the exact same for both cases.
+    embedding_1 = results_1[0]['embedding']
+    embedding_2 = results_2[0]['embedding']
+    assert embedding_1 == embedding_2
+    logger.info("✅ test represent function for consistent output.")
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_list_input(model_name):
     img_paths = [
         "dataset/img1.jpg",
         "dataset/img2.jpg",
         "dataset/img3.jpg",
         "dataset/img4.jpg",
         "dataset/img5.jpg",
+        "dataset/couple.jpg",
     ]
+    expected_faces = [1, 1, 1, 1, 1, 2]

-    embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)
-    assert len(embedding_objs) == len(img_paths), f"Expected {len(img_paths)} embeddings, got {len(embedding_objs)}"
+    batched_embedding_objs = DeepFace.represent(img_path=img_paths, model_name=model_name)

-    if model_name == "VGG-Face":
-        for embedding_obj in embedding_objs:
-            embedding = embedding_obj["embedding"]
-            logger.debug(f"Function returned {len(embedding)} dimensional vector")
-            assert len(embedding) == 4096, f"Expected embedding of length 4096, got {len(embedding)}"
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+
+    assert len(batched_embedding_objs) == len(
+        img_paths
+    ), f"Expected {len(img_paths)} embeddings, got {len(batched_embedding_objs)}"
+
+    # the last one has two faces
+    for idx, embedding_objs in enumerate(batched_embedding_objs):
+        # batched_embedding_objs is the outer list; each embedding_objs should be a list of dicts
+        assert isinstance(embedding_objs, list)
+        for embedding_obj in embedding_objs:
+            assert isinstance(embedding_obj, dict)
+
+        assert expected_faces[idx] == len(
+            embedding_objs
+        ), f"{img_paths[idx]} has {expected_faces[idx]} faces, but got {len(embedding_objs)} embeddings!"

-    embedding_objs_one_by_one = [
-        embedding_obj
-        for img_path in img_paths
-        for embedding_obj in DeepFace.represent(img_path=img_path, model_name=model_name)
-    ]
-    for embedding_obj_one_by_one, embedding_obj in zip(embedding_objs_one_by_one, embedding_objs):
-        assert np.allclose(
-            embedding_obj_one_by_one["embedding"],
-            embedding_obj["embedding"],
-            rtol=1e-2,
-            atol=1e-2
-        ), "Embeddings do not match within tolerance"
-
-    logger.info(f"✅ test batch represent function for model {model_name} done")
+    for idx, img_path in enumerate(img_paths):
+        single_embedding_objs = DeepFace.represent(img_path=img_path, model_name=model_name)
+        # type should be list of dict for single input
+        assert isinstance(single_embedding_objs, list)
+        for embedding_obj in single_embedding_objs:
+            assert isinstance(embedding_obj, dict)
+
+        assert len(single_embedding_objs) == len(batched_embedding_objs[idx])
+
+        for alpha, beta in zip(single_embedding_objs, batched_embedding_objs[idx]):
+            assert np.allclose(
+                alpha["embedding"], beta["embedding"], rtol=1e-2, atol=1e-2
+            ), "Embeddings do not match within tolerance"
+
+    logger.info(f"✅ test batch represent function with string input for model {model_name} done")
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "VGG-Face",
+        "Facenet",
+        "SFace",
+    ],
+)
+def test_batched_represent_for_numpy_input(model_name):
+    img_paths = [
+        "dataset/img1.jpg",
+        "dataset/img2.jpg",
+        "dataset/img3.jpg",
+        "dataset/img4.jpg",
+        "dataset/img5.jpg",
+        "dataset/couple.jpg",
+    ]
+    expected_faces = [1, 1, 1, 1, 1, 2]
+
+    imgs = []
+    for img_path in img_paths:
+        img = cv2.imread(img_path)
+        img = cv2.resize(img, (1000, 1000))
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        imgs.append(img)
+
+    imgs = np.array(imgs)
+    assert imgs.ndim == 4 and imgs.shape[0] == len(img_paths)
+
+    batched_embedding_objs = DeepFace.represent(img_path=imgs, model_name=model_name)
+
+    # type should be list of list of dict for batch input
+    assert isinstance(batched_embedding_objs, list)
+    for idx, batched_embedding_obj in enumerate(batched_embedding_objs):
+        assert isinstance(batched_embedding_obj, list)
+        # it also has to have the expected number of faces
+        assert len(batched_embedding_obj) == expected_faces[idx]
+        for embedding_obj in batched_embedding_obj:
+            assert isinstance(embedding_obj, dict)
+
+    # we should have the same number of embeddings as the number of images
+    assert len(batched_embedding_objs) == len(img_paths)
+
+    logger.info(f"✅ test batch represent function with numpy input for model {model_name} done")