moving logic under modules

- all extract_faces logic moved to the detection module
- image loading logic moved to the preprocessing module (a call-site sketch follows below)
Sefik Ilkin Serengil 2024-01-31 19:12:16 +00:00
parent 9fbb229b97
commit 95c55c0401
9 changed files with 316 additions and 329 deletions
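For context, a rough sketch of the call-site change this commit applies across the package; the image path is only a placeholder, and the argument names follow the hunks below:

# before this commit: tuple interface in deepface.commons.functions
from deepface.commons import functions
for face, region, confidence in functions.extract_faces(img="dataset/img11.jpg"):
    print(region, confidence)

# after this commit: dict interface in deepface.modules.detection
from deepface.modules import detection
for face_obj in detection.extract_faces(img_path="dataset/img11.jpg"):
    print(face_obj["facial_area"], face_obj["confidence"])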

deepface/DeepFace.py (View File)

@ -449,6 +449,7 @@ def extract_faces(
enforce_detection=enforce_detection,
align=align,
grayscale=grayscale,
+ human_readable=True,
)

deepface/commons/functions.py (View File)

@ -1,42 +1,19 @@
import os
- from typing import Union, Tuple, List
- import base64
from pathlib import Path
# 3rd party dependencies
- from PIL import Image
- import requests
- import numpy as np
- import cv2
import tensorflow as tf
# package dependencies
- from deepface.detectors import DetectorWrapper
- from deepface.models.Detector import DetectedFace, FacialAreaRegion
from deepface.commons.logger import Logger
logger = Logger(module="commons.functions")
- # pylint: disable=no-else-raise
- # --------------------------------------------------
- # configurations of dependencies
def get_tf_major_version() -> int:
return int(tf.__version__.split(".", maxsplit=1)[0])
- tf_major_version = get_tf_major_version()
- if tf_major_version == 1:
- from keras.preprocessing import image
- elif tf_major_version == 2:
- from tensorflow.keras.preprocessing import image
- # --------------------------------------------------
def initialize_folder() -> None:
"""Initialize the folder for storing weights and models.
@ -65,266 +42,6 @@ def get_deepface_home() -> str:
return str(os.getenv("DEEPFACE_HOME", default=str(Path.home())))
# --------------------------------------------------
def loadBase64Img(uri: str) -> np.ndarray:
"""Load image from base64 string.
Args:
uri: a base64 string.
Returns:
numpy array: the loaded image.
"""
encoded_data = uri.split(",")[1]
nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8)
img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
return img_bgr
def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
"""
Load image from path, url, base64 or numpy array.
Args:
img: a path, url, base64 or numpy array.
Returns:
image (numpy array): the loaded image in BGR format
image name (str): image name itself
"""
# The image is already a numpy array
if isinstance(img, np.ndarray):
return img, "numpy array"
if isinstance(img, Path):
img = str(img)
if not isinstance(img, str):
raise ValueError(f"img must be numpy array or str but it is {type(img)}")
# The image is a base64 string
if img.startswith("data:image/"):
return loadBase64Img(img), "base64 encoded string"
# The image is a url
if img.startswith("http"):
return (
np.array(Image.open(requests.get(img, stream=True, timeout=60).raw).convert("BGR")),
# return url as image name
img,
)
# The image is a path
if os.path.isfile(img) is not True:
raise ValueError(f"Confirm that {img} exists")
# image must be a file on the system then
# image name must have english characters
if img.isascii() is False:
raise ValueError(f"Input image must not have non-english characters - {img}")
img_obj_bgr = cv2.imread(img)
# img_obj_rgb = cv2.cvtColor(img_obj_bgr, cv2.COLOR_BGR2RGB)
return img_obj_bgr, img
# --------------------------------------------------
def extract_faces(
img: Union[str, np.ndarray],
target_size: tuple = (224, 224),
detector_backend: str = "opencv",
grayscale: bool = False,
enforce_detection: bool = True,
align: bool = True,
) -> List[Tuple[np.ndarray, dict, float]]:
"""
Extract faces from an image.
Args:
img: a path, url, base64 or numpy array.
target_size (tuple, optional): the target size of the extracted faces.
Defaults to (224, 224).
detector_backend (str, optional): the face detector backend. Defaults to "opencv".
grayscale (bool, optional): whether to convert the extracted faces to grayscale.
Defaults to False.
enforce_detection (bool, optional): whether to enforce face detection. Defaults to True.
align (bool, optional): whether to align the extracted faces. Defaults to True.
Raises:
ValueError: if face could not be detected and enforce_detection is True.
Returns:
results (List[Tuple[np.ndarray, dict, float]]): A list of tuples
where each tuple contains:
- detected_face (np.ndarray): The detected face as a NumPy array.
- face_region (dict): The image region represented as
{"x": x, "y": y, "w": w, "h": h}
- confidence (float): The confidence score associated with the detected face.
"""
# this is going to store a list of img itself (numpy), it region and confidence
extracted_faces = []
# img might be path, base64 or numpy array. Convert it to numpy whatever it is.
img, img_name = load_image(img)
base_region = FacialAreaRegion(x=0, y=0, w=img.shape[1], h=img.shape[0])
if detector_backend == "skip":
face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
else:
face_objs = DetectorWrapper.detect_faces(detector_backend, img, align)
# in case of no face found
if len(face_objs) == 0 and enforce_detection is True:
if img_name is not None:
raise ValueError(
f"Face could not be detected in {img_name}."
"Please confirm that the picture is a face photo "
"or consider to set enforce_detection param to False."
)
else:
raise ValueError(
"Face could not be detected. Please confirm that the picture is a face photo "
"or consider to set enforce_detection param to False."
)
if len(face_objs) == 0 and enforce_detection is False:
face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
for face_obj in face_objs:
current_img = face_obj.img
current_region = face_obj.facial_area
confidence = face_obj.confidence
if current_img.shape[0] > 0 and current_img.shape[1] > 0:
if grayscale is True:
current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
# resize and padding
factor_0 = target_size[0] / current_img.shape[0]
factor_1 = target_size[1] / current_img.shape[1]
factor = min(factor_0, factor_1)
dsize = (
int(current_img.shape[1] * factor),
int(current_img.shape[0] * factor),
)
current_img = cv2.resize(current_img, dsize)
diff_0 = target_size[0] - current_img.shape[0]
diff_1 = target_size[1] - current_img.shape[1]
if grayscale is False:
# Put the base image in the middle of the padded image
current_img = np.pad(
current_img,
(
(diff_0 // 2, diff_0 - diff_0 // 2),
(diff_1 // 2, diff_1 - diff_1 // 2),
(0, 0),
),
"constant",
)
else:
current_img = np.pad(
current_img,
(
(diff_0 // 2, diff_0 - diff_0 // 2),
(diff_1 // 2, diff_1 - diff_1 // 2),
),
"constant",
)
# double check: if target image is not still the same size with target.
if current_img.shape[0:2] != target_size:
current_img = cv2.resize(current_img, target_size)
# normalizing the image pixels
# what this line doing? must?
img_pixels = image.img_to_array(current_img)
img_pixels = np.expand_dims(img_pixels, axis=0)
img_pixels /= 255 # normalize input in [0, 1]
# int cast is for the exception - object of type 'float32' is not JSON serializable
region_obj = {
"x": current_region.x,
"y": current_region.y,
"w": current_region.w,
"h": current_region.h,
}
extracted_face = (img_pixels, region_obj, confidence)
extracted_faces.append(extracted_face)
if len(extracted_faces) == 0 and enforce_detection == True:
raise ValueError(
f"Detected face shape is {img.shape}. Consider to set enforce_detection arg to False."
)
return extracted_faces
def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray:
"""Normalize input image.
Args:
img (numpy array): the input image.
normalization (str, optional): the normalization technique. Defaults to "base",
for no normalization.
Returns:
numpy array: the normalized image.
"""
# issue 131 declares that some normalization techniques improves the accuracy
if normalization == "base":
return img
# @trevorgribble and @davedgd contributed this feature
# restore input in scale of [0, 255] because it was normalized in scale of
# [0, 1] in preprocess_face
img *= 255
if normalization == "raw":
pass # return just restored pixels
elif normalization == "Facenet":
mean, std = img.mean(), img.std()
img = (img - mean) / std
elif normalization == "Facenet2018":
# simply / 127.5 - 1 (similar to facenet 2018 model preprocessing step as @iamrishab posted)
img /= 127.5
img -= 1
elif normalization == "VGGFace":
# mean subtraction based on VGGFace1 training data
img[..., 0] -= 93.5940
img[..., 1] -= 104.7624
img[..., 2] -= 129.1863
elif normalization == "VGGFace2":
# mean subtraction based on VGGFace2 training data
img[..., 0] -= 91.4953
img[..., 1] -= 103.8827
img[..., 2] -= 131.0912
elif normalization == "ArcFace":
# Reference study: The faces are cropped and resized to 112×112,
# and each pixel (ranged between [0, 255]) in RGB images is normalised
# by subtracting 127.5 then divided by 128.
img -= 127.5
img /= 128
else:
raise ValueError(f"unimplemented normalization type - {normalization}")
return img
def find_target_size(model_name: str) -> tuple:
"""Find the target size of the model.

deepface/modules/demography.py (View File)

@ -6,8 +6,7 @@ import numpy as np
from tqdm import tqdm
# project dependencies
- from deepface.modules import modeling
- from deepface.commons import functions
+ from deepface.modules import modeling, detection
from deepface.extendedmodels import Gender, Race, Emotion
@ -114,8 +113,8 @@ def analyze(
# ---------------------------------
resp_objects = []
- img_objs = functions.extract_faces(
- img=img_path,
+ img_objs = detection.extract_faces(
+ img_path=img_path,
target_size=(224, 224),
detector_backend=detector_backend,
grayscale=False,
@ -123,7 +122,10 @@ def analyze(
align=align,
)
- for img_content, img_region, img_confidence in img_objs:
+ for img_obj in img_objs:
+ img_content = img_obj["face"]
+ img_region = img_obj["facial_area"]
+ img_confidence = img_obj["confidence"]
if img_content.shape[0] > 0 and img_content.shape[1] > 0:
obj = {}
# facial attribute analysis

deepface/modules/detection.py (View File)

@ -3,10 +3,26 @@ from typing import Any, Dict, List, Tuple, Union
# 3rd part dependencies
import numpy as np
+ import cv2
from PIL import Image
# project dependencies
+ from deepface.modules import preprocessing
+ from deepface.models.Detector import DetectedFace, FacialAreaRegion
+ from deepface.detectors import DetectorWrapper
from deepface.commons import functions
+ from deepface.commons.logger import Logger
+ logger = Logger(module="deepface/modules/detection.py")
+ # pylint: disable=no-else-raise
+ tf_major_version = functions.get_tf_major_version()
+ if tf_major_version == 1:
+ from keras.preprocessing import image
+ elif tf_major_version == 2:
+ from tensorflow.keras.preprocessing import image
def extract_faces(
@ -16,6 +32,7 @@ def extract_faces(
enforce_detection: bool = True,
align: bool = True,
grayscale: bool = False,
+ human_readable=False,
) -> List[Dict[str, Any]]:
"""
Extract faces from a given image
@ -38,6 +55,8 @@ def extract_faces(
grayscale (boolean): Flag to convert the image to grayscale before
processing (default is False).
+ human_readable (bool): Flag to make the image human readable. 3D RGB for human readable
+ or 4D BGR for ML models (default is False).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:
@ -48,27 +67,108 @@ def extract_faces(
resp_objs = []
- img_objs = functions.extract_faces(
- img=img_path,
- target_size=target_size,
- detector_backend=detector_backend,
- grayscale=grayscale,
- enforce_detection=enforce_detection,
- align=align,
- )
- for img, region, confidence in img_objs:
- resp_obj = {}
+ # img might be path, base64 or numpy array. Convert it to numpy whatever it is.
+ img, img_name = preprocessing.load_image(img_path)
+ base_region = FacialAreaRegion(x=0, y=0, w=img.shape[1], h=img.shape[0])
+ if detector_backend == "skip":
+ face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
+ else:
+ face_objs = DetectorWrapper.detect_faces(detector_backend, img, align)
+ # in case of no face found
+ if len(face_objs) == 0 and enforce_detection is True:
+ if img_name is not None:
+ raise ValueError(
+ f"Face could not be detected in {img_name}."
+ "Please confirm that the picture is a face photo "
+ "or consider to set enforce_detection param to False."
+ )
+ else:
+ raise ValueError(
+ "Face could not be detected. Please confirm that the picture is a face photo "
+ "or consider to set enforce_detection param to False."
+ )
+ if len(face_objs) == 0 and enforce_detection is False:
+ face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
+ for face_obj in face_objs:
+ current_img = face_obj.img
+ current_region = face_obj.facial_area
+ confidence = face_obj.confidence
+ if current_img.shape[0] == 0 or current_img.shape[1] == 0:
+ continue
+ if grayscale is True:
+ current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
+ # resize and padding
+ factor_0 = target_size[0] / current_img.shape[0]
+ factor_1 = target_size[1] / current_img.shape[1]
+ factor = min(factor_0, factor_1)
+ dsize = (
+ int(current_img.shape[1] * factor),
+ int(current_img.shape[0] * factor),
+ )
+ current_img = cv2.resize(current_img, dsize)
+ diff_0 = target_size[0] - current_img.shape[0]
+ diff_1 = target_size[1] - current_img.shape[1]
+ if grayscale is False:
+ # Put the base image in the middle of the padded image
+ current_img = np.pad(
+ current_img,
+ (
+ (diff_0 // 2, diff_0 - diff_0 // 2),
+ (diff_1 // 2, diff_1 - diff_1 // 2),
+ (0, 0),
+ ),
+ "constant",
+ )
+ else:
+ current_img = np.pad(
+ current_img,
+ (
+ (diff_0 // 2, diff_0 - diff_0 // 2),
+ (diff_1 // 2, diff_1 - diff_1 // 2),
+ ),
+ "constant",
+ )
+ # double check: if target image is not still the same size with target.
+ if current_img.shape[0:2] != target_size:
+ current_img = cv2.resize(current_img, target_size)
+ # normalizing the image pixels
+ # what this line doing? must?
+ img_pixels = image.img_to_array(current_img)
+ img_pixels = np.expand_dims(img_pixels, axis=0)
+ img_pixels /= 255 # normalize input in [0, 1]
# discard expanded dimension
- if len(img.shape) == 4:
- img = img[0]
- # bgr to rgb
- resp_obj["face"] = img[:, :, ::-1]
- resp_obj["facial_area"] = region
- resp_obj["confidence"] = confidence
- resp_objs.append(resp_obj)
+ if human_readable is True and len(img_pixels.shape) == 4:
+ img_pixels = img_pixels[0]
+ resp_objs.append(
+ {
+ "face": img_pixels[:, :, ::-1] if human_readable is True else img_pixels,
+ "facial_area": {
+ "x": current_region.x,
+ "y": current_region.y,
+ "w": current_region.w,
+ "h": current_region.h,
+ },
+ "confidence": confidence,
+ }
+ )
+ if len(resp_objs) == 0 and enforce_detection == True:
+ raise ValueError(
+ f"Detected face shape is {img.shape}. Consider to set enforce_detection arg to False."
+ )
return resp_objs
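A minimal usage sketch of the relocated function with the new human_readable flag, assuming an example image path; argument names and the returned dictionary keys are taken from the hunk above:

from deepface.modules import detection

face_objs = detection.extract_faces(
    img_path="dataset/img11.jpg",  # placeholder; a url, base64 string or numpy array also works
    target_size=(224, 224),
    detector_backend="opencv",
    human_readable=True,           # 3D RGB face scaled to [0, 1]; False keeps the 4D BGR tensor fed to models
)
for face_obj in face_objs:
    print(face_obj["face"].shape)   # (224, 224, 3) when human_readable is True
    print(face_obj["facial_area"])  # {"x": ..., "y": ..., "w": ..., "h": ...}
    print(face_obj["confidence"])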

deepface/modules/preprocessing.py (View File)

@ -0,0 +1,131 @@
import os
from typing import Union, Tuple
import base64
from pathlib import Path
# 3rd party
import numpy as np
import cv2
from PIL import Image
import requests
def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
"""
Load image from path, url, base64 or numpy array.
Args:
img: a path, url, base64 or numpy array.
Returns:
image (numpy array): the loaded image in BGR format
image name (str): image name itself
"""
# The image is already a numpy array
if isinstance(img, np.ndarray):
return img, "numpy array"
if isinstance(img, Path):
img = str(img)
if not isinstance(img, str):
raise ValueError(f"img must be numpy array or str but it is {type(img)}")
# The image is a base64 string
if img.startswith("data:image/"):
return load_base64(img), "base64 encoded string"
# The image is a url
if img.startswith("http"):
return (
np.array(Image.open(requests.get(img, stream=True, timeout=60).raw).convert("BGR")),
# return url as image name
img,
)
# The image is a path
if os.path.isfile(img) is not True:
raise ValueError(f"Confirm that {img} exists")
# image must be a file on the system then
# image name must have english characters
if img.isascii() is False:
raise ValueError(f"Input image must not have non-english characters - {img}")
img_obj_bgr = cv2.imread(img)
# img_obj_rgb = cv2.cvtColor(img_obj_bgr, cv2.COLOR_BGR2RGB)
return img_obj_bgr, img
def load_base64(uri: str) -> np.ndarray:
"""Load image from base64 string.
Args:
uri: a base64 string.
Returns:
numpy array: the loaded image.
"""
encoded_data = uri.split(",")[1]
nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8)
img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
return img_bgr
def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray:
"""Normalize input image.
Args:
img (numpy array): the input image.
normalization (str, optional): the normalization technique. Defaults to "base",
for no normalization.
Returns:
numpy array: the normalized image.
"""
# issue 131 declares that some normalization techniques improves the accuracy
if normalization == "base":
return img
# @trevorgribble and @davedgd contributed this feature
# restore input in scale of [0, 255] because it was normalized in scale of
# [0, 1] in preprocess_face
img *= 255
if normalization == "raw":
pass # return just restored pixels
elif normalization == "Facenet":
mean, std = img.mean(), img.std()
img = (img - mean) / std
elif normalization == "Facenet2018":
# simply / 127.5 - 1 (similar to facenet 2018 model preprocessing step as @iamrishab posted)
img /= 127.5
img -= 1
elif normalization == "VGGFace":
# mean subtraction based on VGGFace1 training data
img[..., 0] -= 93.5940
img[..., 1] -= 104.7624
img[..., 2] -= 129.1863
elif normalization == "VGGFace2":
# mean subtraction based on VGGFace2 training data
img[..., 0] -= 91.4953
img[..., 1] -= 103.8827
img[..., 2] -= 131.0912
elif normalization == "ArcFace":
# Reference study: The faces are cropped and resized to 112×112,
# and each pixel (ranged between [0, 255]) in RGB images is normalised
# by subtracting 127.5 then divided by 128.
img -= 127.5
img /= 128
else:
raise ValueError(f"unimplemented normalization type - {normalization}")
return img
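A brief sketch of the new preprocessing helpers used on their own, assuming an example image path; load_image returns a BGR numpy array plus the image name, and normalize_input expects pixels already scaled to [0, 1]:

from deepface.modules import preprocessing

img, name = preprocessing.load_image("dataset/img1.jpg")  # placeholder; a url, base64 string or numpy array also works
print(img.shape, name)

face_pixels = img.astype("float32") / 255.0  # normalize_input assumes [0, 1] input, as noted in the comments above
normalized = preprocessing.normalize_input(img=face_pixels, normalization="Facenet2018")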

deepface/modules/recognition.py (View File)

@ -12,7 +12,7 @@ from tqdm import tqdm
# project dependencies
from deepface.commons import functions, distance as dst
from deepface.commons.logger import Logger
- from deepface.modules import representation
+ from deepface.modules import representation, detection
logger = Logger(module="deepface/modules/recognition.py")
@ -202,8 +202,8 @@ def find(
)
# img path might have more than once face
- source_objs = functions.extract_faces(
- img=img_path,
+ source_objs = detection.extract_faces(
+ img_path=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
@ -213,7 +213,9 @@ def find(
resp_obj = []
- for source_img, source_region, _ in source_objs:
+ for source_obj in source_objs:
+ source_img = source_obj["face"]
+ source_region = source_obj["facial_area"]
target_embedding_obj = representation.represent(
img_path=source_img,
model_name=model_name,
@ -333,8 +335,8 @@ def __find_bulk_embeddings(
desc="Finding representations",
disable=silent,
):
- img_objs = functions.extract_faces(
- img=employee,
+ img_objs = detection.extract_faces(
+ img_path=employee,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
@ -342,7 +344,9 @@ def __find_bulk_embeddings(
align=align,
)
- for img_content, img_region, _ in img_objs:
+ for img_obj in img_objs:
+ img_content = img_obj["face"]
+ img_region = img_obj["facial_area"]
embedding_obj = representation.represent(
img_path=img_content,
model_name=model_name,

deepface/modules/representation.py (View File)

@ -6,7 +6,7 @@ import numpy as np
import cv2
# project dependencies
- from deepface.modules import modeling
+ from deepface.modules import modeling, detection, preprocessing
from deepface.commons import functions
from deepface.models.FacialRecognition import FacialRecognition
@ -63,8 +63,8 @@ def represent(
# we have run pre-process in verification. so, this can be skipped if it is coming from verify.
target_size = functions.find_target_size(model_name=model_name)
if detector_backend != "skip":
- img_objs = functions.extract_faces(
- img=img_path,
+ img_objs = detection.extract_faces(
+ img_path=img_path,
target_size=(target_size[1], target_size[0]),
detector_backend=detector_backend,
grayscale=False,
@ -73,7 +73,7 @@ def represent(
)
else: # skip
# Try load. If load error, will raise exception internal
- img, _ = functions.load_image(img_path)
+ img, _ = preprocessing.load_image(img_path)
# --------------------------------
if len(img.shape) == 4:
img = img[0] # e.g. (1, 224, 224, 3) to (224, 224, 3)
@ -85,13 +85,21 @@ def represent(
img = (img.astype(np.float32) / 255.0).astype(np.float32)
# --------------------------------
# make dummy region and confidence to keep compatibility with `extract_faces`
- img_region = {"x": 0, "y": 0, "w": img.shape[1], "h": img.shape[2]}
- img_objs = [(img, img_region, 0)]
+ img_objs = [
+ {
+ "face": img,
+ "facial_area": {"x": 0, "y": 0, "w": img.shape[1], "h": img.shape[2]},
+ "confidence": 0,
+ }
+ ]
# ---------------------------------
- for img, region, confidence in img_objs:
+ for img_obj in img_objs:
+ img = img_obj["face"]
+ region = img_obj["facial_area"]
+ confidence = img_obj["confidence"]
# custom normalization
- img = functions.normalize_input(img=img, normalization=normalization)
+ img = preprocessing.normalize_input(img=img, normalization=normalization)
embedding = model.find_embeddings(img)

deepface/modules/verification.py (View File)

@ -7,7 +7,7 @@ import numpy as np
# project dependencies
from deepface.commons import functions, distance as dst
- from deepface.modules import representation
+ from deepface.modules import representation, detection
def verify(
@ -82,8 +82,8 @@ def verify(
target_size = functions.find_target_size(model_name=model_name)
# img pairs might have many faces
- img1_objs = functions.extract_faces(
- img=img1_path,
+ img1_objs = detection.extract_faces(
+ img_path=img1_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
@ -91,8 +91,8 @@ def verify(
align=align,
)
- img2_objs = functions.extract_faces(
- img=img2_path,
+ img2_objs = detection.extract_faces(
+ img_path=img2_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
@ -103,8 +103,12 @@ def verify(
distances = []
regions = []
# now we will find the face pair with minimum distance
- for img1_content, img1_region, _ in img1_objs:
- for img2_content, img2_region, _ in img2_objs:
+ for img1_obj in img1_objs:
+ img1_content = img1_obj["face"]
+ img1_region = img1_obj["facial_area"]
+ for img2_obj in img2_objs:
+ img2_content = img2_obj["face"]
+ img2_region = img2_obj["facial_area"]
img1_embedding_obj = representation.represent(
img_path=img1_content,
model_name=model_name,

tests/test_extract_faces.py (View File)

@ -1,12 +1,14 @@
+ import numpy as np
+ import pytest
from deepface import DeepFace
from deepface.commons.logger import Logger
logger = Logger("tests/test_extract_faces.py")
+ detectors = ["opencv", "mtcnn"]
def test_different_detectors():
- detectors = ["opencv", "mtcnn"]
for detector in detectors:
img_objs = DeepFace.extract_faces(img_path="dataset/img11.jpg", detector_backend=detector)
for img_obj in img_objs:
@ -22,3 +24,21 @@ def test_different_detectors():
img = img_obj["face"]
assert img.shape[0] > 0 and img.shape[1] > 0
logger.info(f"✅ extract_faces for {detector} backend test is done")
def test_backends_for_enforced_detection_with_non_facial_inputs():
black_img = np.zeros([224, 224, 3])
for detector in detectors:
with pytest.raises(ValueError):
_ = DeepFace.extract_faces(img_path=black_img, detector_backend=detector)
logger.info("✅ extract_faces for enforced detection and non-facial image test is done")
def test_backends_for_not_enforced_detection_with_non_facial_inputs():
black_img = np.zeros([224, 224, 3])
for detector in detectors:
objs = DeepFace.extract_faces(
img_path=black_img, detector_backend=detector, enforce_detection=False
)
assert objs[0]["face"].shape == (224, 224, 3)
logger.info("✅ extract_faces for not enforced detection and non-facial image test is done")