Merge pull request #1025 from serengil/feat-task-1602-new-detector-interface

Feat task 1602 new detector interface
Sefik Ilkin Serengil 2024-02-16 17:50:50 +00:00 committed by GitHub
commit ee4ad41c2c
16 changed files with 357 additions and 462 deletions

View File

@ -423,7 +423,7 @@ def stream(
def extract_faces( def extract_faces(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray],
target_size: Tuple[int, int] = (224, 224), target_size: Optional[Tuple[int, int]] = (224, 224),
detector_backend: str = "opencv", detector_backend: str = "opencv",
enforce_detection: bool = True, enforce_detection: bool = True,
align: bool = True, align: bool = True,

View File

@ -1,6 +1,7 @@
from typing import Any, List from typing import Any, List, Tuple
import numpy as np import numpy as np
from deepface.models.Detector import Detector, DetectedFace from deepface.modules import detection
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.detectors import ( from deepface.detectors import (
FastMtCnn, FastMtCnn,
MediaPipe, MediaPipe,
@ -80,10 +81,101 @@ def detect_faces(
- confidence (float): The confidence score associated with the detected face. - confidence (float): The confidence score associated with the detected face.
""" """
face_detector: Detector = build_model(detector_backend) face_detector: Detector = build_model(detector_backend)
# validate expand percentage
if expand_percentage < 0: if expand_percentage < 0:
logger.warn( logger.warn(
f"Expand percentage cannot be negative but you set it to {expand_percentage}." f"Expand percentage cannot be negative but you set it to {expand_percentage}."
"Overwritten it to 0." "Overwritten it to 0."
) )
expand_percentage = 0 expand_percentage = 0
return face_detector.detect_faces(img=img, align=align, expand_percentage=expand_percentage)
# find facial areas of the given image
facial_areas = face_detector.detect_faces(img=img)
results = []
for facial_area in facial_areas:
x = facial_area.x
y = facial_area.y
w = facial_area.w
h = facial_area.h
left_eye = facial_area.left_eye
right_eye = facial_area.right_eye
confidence = facial_area.confidence
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * 2 * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * 2 * expand_percentage) / 100)) # expand bottom
# extract detected face unaligned
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
# aligning the detected face crop directly causes a lot of black pixels
# if align is True:
# detected_face, _ = detection.align_face(
# img=detected_face, left_eye=left_eye, right_eye=right_eye
# )
# align original image, then find projection of detected face area after alignment
if align is True: # and left_eye is not None and right_eye is not None:
aligned_img, angle = detection.align_face(
img=img, left_eye=left_eye, right_eye=right_eye
)
x1_new, y1_new, x2_new, y2_new = rotate_facial_area(
facial_area=(x2, y2, x2 + w2, y2 + h2), angle=angle, direction=1, size=img.shape
)
detected_face = aligned_img[int(y1_new) : int(y2_new), int(x1_new) : int(x2_new)]
result = DetectedFace(
img=detected_face,
facial_area=FacialAreaRegion(
x=x, y=y, h=h, w=w, confidence=confidence, left_eye=left_eye, right_eye=right_eye
),
confidence=confidence,
)
results.append(result)
return results
def rotate_facial_area(
facial_area: Tuple[int, int, int, int], angle: float, direction: int, size: Tuple[int, int]
) -> Tuple[int, int, int, int]:
"""
Rotate the facial area around its center.
Inspired by the work of @UmutDeniz26 - github.com/serengil/retinaface/pull/80
Args:
facial_area (tuple of int): Representing the (x1, y1, x2, y2) of the facial area.
x2 is equal to x1 + w1, and y2 is equal to y1 + h1
angle (float): Angle of rotation in degrees.
direction (int): Direction of rotation (-1 for clockwise, 1 for counterclockwise).
size (tuple of int): Tuple representing the size of the image (height, width), i.e. img.shape.
Returns:
rotated_coordinates (tuple of int): Representing the new coordinates
(x1, y1, x2, y2) or (x1, y1, x1+w1, y1+h1) of the rotated facial area.
"""
# Angle in radians
angle = angle * np.pi / 180
# Translate the facial area to the center of the image
x = (facial_area[0] + facial_area[2]) / 2 - size[1] / 2
y = (facial_area[1] + facial_area[3]) / 2 - size[0] / 2
# Rotate the facial area
x_new = x * np.cos(angle) + y * direction * np.sin(angle)
y_new = -x * direction * np.sin(angle) + y * np.cos(angle)
# Translate the facial area back to the original position
x_new = x_new + size[1] / 2
y_new = y_new + size[0] / 2
# Calculate the new facial area
x1 = x_new - (facial_area[2] - facial_area[0]) / 2
y1 = y_new - (facial_area[3] - facial_area[1]) / 2
x2 = x_new + (facial_area[2] - facial_area[0]) / 2
y2 = y_new + (facial_area[3] - facial_area[1]) / 2
return (int(x1), int(y1), int(x2), int(y2))
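The new flow rotates the whole image first and only then crops, so the extracted face keeps real pixels instead of the black corners that rotating a crop would leave. Below is a minimal sketch of that align-then-project sequence, assuming the wrapper module path deepface/detectors/DetectorWrapper.py and using made-up coordinates; it is an illustration, not part of the diff.

import numpy as np
from deepface.modules import detection
from deepface.detectors.DetectorWrapper import rotate_facial_area

img = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a pre-loaded BGR image
x, y, w, h = 100, 120, 80, 90                  # facial area reported by a detector
left_eye, right_eye = (115, 152), (165, 150)   # eye landmarks reported by the detector

# rotate the whole image until the eyes are horizontal; the rotation angle is returned too
aligned_img, angle = detection.align_face(img=img, left_eye=left_eye, right_eye=right_eye)

# project the original facial area onto the rotated image and crop there
x1, y1, x2, y2 = rotate_facial_area(
    facial_area=(x, y, x + w, y + h), angle=angle, direction=1, size=img.shape
)
detected_face = aligned_img[int(y1):int(y2), int(x1):int(x2)]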

View File

@ -4,7 +4,7 @@ import bz2
import gdown import gdown
import numpy as np import numpy as np
from deepface.commons import folder_utils from deepface.commons import folder_utils
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
logger = Logger(module="detectors.DlibWrapper") logger = Logger(module="detectors.DlibWrapper")
@ -56,50 +56,18 @@ class DlibClient(Detector):
detector["sp"] = sp detector["sp"] = sp
return detector return detector
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with dlib Detect faces with dlib
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
# this is not a must dependency. do not import it in the global level.
try:
import dlib
except ModuleNotFoundError as e:
raise ImportError(
"Dlib is an optional detector, ensure the library is installed."
"Please install using 'pip install dlib' "
) from e
if expand_percentage != 0:
logger.warn(
f"You set expand_percentage argument to {expand_percentage},"
"but dlib hog handles detection by itself"
)
resp = [] resp = []
sp = self.model["sp"]
detected_face = None
face_detector = self.model["face_detector"] face_detector = self.model["face_detector"]
# note that, by design, dlib's fhog face detector scores are >0 but not capped at 1 # note that, by design, dlib's fhog face detector scores are >0 but not capped at 1
@ -107,30 +75,32 @@ class DlibClient(Detector):
if len(detections) > 0: if len(detections) > 0:
for idx, d in enumerate(detections): for idx, detection in enumerate(detections):
left = d.left() left = detection.left()
right = d.right() right = detection.right()
top = d.top() top = detection.top()
bottom = d.bottom() bottom = detection.bottom()
y = int(max(0, top)) y = int(max(0, top))
h = int(min(bottom, img.shape[0]) - y) h = int(min(bottom, img.shape[0]) - y)
x = int(max(0, left)) x = int(max(0, left))
w = int(min(right, img.shape[1]) - x) w = int(min(right, img.shape[1]) - x)
detected_face = img[int(y) : int(y + h), int(x) : int(x + w)] shape = self.model["sp"](img, detection)
left_eye = (shape.part(2).x, shape.part(2).y)
right_eye = (shape.part(0).x, shape.part(0).y)
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = scores[idx] confidence = scores[idx]
if align: facial_area = FacialAreaRegion(
img_shape = sp(img, detections[idx]) x=x,
detected_face = dlib.get_face_chip(img, img_shape, size=detected_face.shape[0]) y=y,
w=w,
detected_face_obj = DetectedFace( h=h,
img=detected_face, facial_area=img_region, confidence=confidence left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
) )
resp.append(facial_area)
resp.append(detected_face_obj)
return resp return resp

View File

@ -1,8 +1,7 @@
from typing import Any, Union, List from typing import Any, Union, List
import cv2 import cv2
import numpy as np import numpy as np
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.modules import detection
# Link -> https://github.com/timesler/facenet-pytorch # Link -> https://github.com/timesler/facenet-pytorch
# Examples https://www.kaggle.com/timesler/guide-to-mtcnn-in-facenet-pytorch # Examples https://www.kaggle.com/timesler/guide-to-mtcnn-in-facenet-pytorch
@ -12,33 +11,18 @@ class FastMtCnnClient(Detector):
def __init__(self): def __init__(self):
self.model = self.build_model() self.model = self.build_model()
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with mtcnn Detect faces with mtcnn
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
resp = [] resp = []
detected_face = None
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # mtcnn expects RGB but OpenCV reads BGR img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # mtcnn expects RGB but OpenCV reads BGR
detections = self.model.detect( detections = self.model.detect(
img_rgb, landmarks=True img_rgb, landmarks=True
@ -47,31 +31,20 @@ class FastMtCnnClient(Detector):
for current_detection in zip(*detections): for current_detection in zip(*detections):
x, y, w, h = xyxy_to_xywh(current_detection[0]) x, y, w, h = xyxy_to_xywh(current_detection[0])
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = current_detection[1] confidence = current_detection[1]
if align:
left_eye = current_detection[2][0] left_eye = current_detection[2][0]
right_eye = current_detection[2][1] right_eye = current_detection[2][1]
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye, right_eye=right_eye
)
detected_face_obj = DetectedFace( facial_area = FacialAreaRegion(
img=detected_face, facial_area=img_region, confidence=confidence x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
) )
resp.append(facial_area)
resp.append(detected_face_obj)
return resp return resp

View File

@ -1,7 +1,6 @@
from typing import Any, List from typing import Any, List
import numpy as np import numpy as np
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.modules import detection
# Link - https://google.github.io/mediapipe/solutions/face_detection # Link - https://google.github.io/mediapipe/solutions/face_detection
@ -29,28 +28,15 @@ class MediaPipeClient(Detector):
face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.7) face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.7)
return face_detection return face_detection
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with mediapipe Detect faces with mediapipe
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
resp = [] resp = []
@ -75,7 +61,6 @@ class MediaPipeClient(Detector):
y = int(bounding_box.ymin * img_height) y = int(bounding_box.ymin * img_height)
h = int(bounding_box.height * img_height) h = int(bounding_box.height * img_height)
# Extract landmarks
left_eye = (int(landmarks[0].x * img_width), int(landmarks[0].y * img_height)) left_eye = (int(landmarks[0].x * img_width), int(landmarks[0].y * img_height))
right_eye = (int(landmarks[1].x * img_width), int(landmarks[1].y * img_height)) right_eye = (int(landmarks[1].x * img_width), int(landmarks[1].y * img_height))
# nose = (int(landmarks[2].x * img_width), int(landmarks[2].y * img_height)) # nose = (int(landmarks[2].x * img_width), int(landmarks[2].y * img_height))
@ -83,30 +68,9 @@ class MediaPipeClient(Detector):
# right_ear = (int(landmarks[4].x * img_width), int(landmarks[4].y * img_height)) # right_ear = (int(landmarks[4].x * img_width), int(landmarks[4].y * img_height))
# left_ear = (int(landmarks[5].x * img_width), int(landmarks[5].y * img_height)) # left_ear = (int(landmarks[5].x * img_width), int(landmarks[5].y * img_height))
if x > 0 and y > 0: facial_area = FacialAreaRegion(
x=x, y=y, w=w, h=h, left_eye=left_eye, right_eye=right_eye, confidence=confidence
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
if align:
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye, right_eye=right_eye
) )
resp.append(facial_area)
detected_face_obj = DetectedFace(
img=detected_face,
facial_area=img_region,
confidence=confidence,
)
resp.append(detected_face_obj)
return resp return resp

View File

@ -1,8 +1,7 @@
from typing import List from typing import List
import numpy as np import numpy as np
from mtcnn import MTCNN from mtcnn import MTCNN
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.modules import detection
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class MtCnnClient(Detector): class MtCnnClient(Detector):
@ -13,34 +12,19 @@ class MtCnnClient(Detector):
def __init__(self): def __init__(self):
self.model = MTCNN() self.model = MTCNN()
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with mtcnn Detect faces with mtcnn
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
resp = [] resp = []
detected_face = None
# mtcnn expects RGB but OpenCV reads BGR # mtcnn expects RGB but OpenCV reads BGR
# img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_rgb = img[:, :, ::-1] img_rgb = img[:, :, ::-1]
@ -50,31 +34,20 @@ class MtCnnClient(Detector):
for current_detection in detections: for current_detection in detections:
x, y, w, h = current_detection["box"] x, y, w, h = current_detection["box"]
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = current_detection["confidence"] confidence = current_detection["confidence"]
left_eye = current_detection["keypoints"]["left_eye"]
right_eye = current_detection["keypoints"]["right_eye"]
if align: facial_area = FacialAreaRegion(
keypoints = current_detection["keypoints"] x=x,
left_eye = keypoints["left_eye"] y=y,
right_eye = keypoints["right_eye"] w=w,
detected_face = detection.align_face( h=h,
img=detected_face, left_eye=left_eye, right_eye=right_eye left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
) )
detected_face_obj = DetectedFace( resp.append(facial_area)
img=detected_face, facial_area=img_region, confidence=confidence
)
resp.append(detected_face_obj)
return resp return resp

View File

@ -2,8 +2,7 @@ import os
from typing import Any, List from typing import Any, List
import cv2 import cv2
import numpy as np import numpy as np
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.modules import detection
class OpenCvClient(Detector): class OpenCvClient(Detector):
@ -25,28 +24,15 @@ class OpenCvClient(Detector):
detector["eye_detector"] = self.__build_cascade("haarcascade_eye") detector["eye_detector"] = self.__build_cascade("haarcascade_eye")
return detector return detector
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with opencv Detect faces with opencv
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
resp = [] resp = []
@ -65,27 +51,18 @@ class OpenCvClient(Detector):
if len(faces) > 0: if len(faces) > 0:
for (x, y, w, h), confidence in zip(faces, scores): for (x, y, w, h), confidence in zip(faces, scores):
detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
if align:
left_eye, right_eye = self.find_eyes(img=detected_face) left_eye, right_eye = self.find_eyes(img=detected_face)
detected_face = detection.align_face(detected_face, left_eye, right_eye) facial_area = FacialAreaRegion(
x=x,
detected_face_obj = DetectedFace( y=y,
img=detected_face, w=w,
facial_area=FacialAreaRegion(x, y, w, h), h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence, confidence=confidence,
) )
resp.append(facial_area)
resp.append(detected_face_obj)
return resp return resp

View File

@ -1,36 +1,22 @@
from typing import List from typing import List
import numpy as np import numpy as np
from retinaface import RetinaFace as rf from retinaface import RetinaFace as rf
from retinaface.commons import postprocess from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class RetinaFaceClient(Detector): class RetinaFaceClient(Detector):
def __init__(self): def __init__(self):
self.model = rf.build_model() self.model = rf.build_model()
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with retinaface Detect faces with retinaface
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
resp = [] resp = []
@ -41,42 +27,33 @@ class RetinaFaceClient(Detector):
for face_idx in obj.keys(): for face_idx in obj.keys():
identity = obj[face_idx] identity = obj[face_idx]
facial_area = identity["facial_area"] detection = identity["facial_area"]
y = detection[1]
h = detection[3] - y
x = detection[0]
w = detection[2] - x
# note that left and right eyes are swapped for retinaface
left_eye = identity["landmarks"]["right_eye"]
right_eye = identity["landmarks"]["left_eye"]
# eyes are lists of floats; cast them to tuples of ints
left_eye = tuple(int(i) for i in left_eye)
right_eye = tuple(int(i) for i in right_eye)
y = facial_area[1]
h = facial_area[3] - y
x = facial_area[0]
w = facial_area[2] - x
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = identity["score"] confidence = identity["score"]
# expand the facial area to be extracted and stay within img.shape limits facial_area = FacialAreaRegion(
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left x=x,
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top y=y,
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right w=w,
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom h=h,
left_eye=left_eye,
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)] right_eye=right_eye,
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
if align:
landmarks = identity["landmarks"]
left_eye = landmarks["left_eye"]
right_eye = landmarks["right_eye"]
nose = landmarks["nose"]
# mouth_right = landmarks["mouth_right"]
# mouth_left = landmarks["mouth_left"]
detected_face = postprocess.alignment_procedure(
detected_face, right_eye, left_eye, nose
)
detected_face_obj = DetectedFace(
img=detected_face,
facial_area=img_region,
confidence=confidence, confidence=confidence,
) )
resp.append(detected_face_obj) resp.append(facial_area)
return resp return resp

View File

@ -6,8 +6,7 @@ import pandas as pd
import numpy as np import numpy as np
from deepface.detectors import OpenCv from deepface.detectors import OpenCv
from deepface.commons import folder_utils from deepface.commons import folder_utils
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.modules import detection
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
logger = Logger(module="detectors.SsdWrapper") logger = Logger(module="detectors.SsdWrapper")
@ -71,29 +70,18 @@ class SsdClient(Detector):
return detector return detector
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with ssd Detect faces with ssd
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
opencv_module: OpenCv.OpenCvClient = self.model["opencv_module"]
resp = [] resp = []
detected_face = None detected_face = None
@ -133,37 +121,26 @@ class SsdClient(Detector):
right = instance["right"] right = instance["right"]
bottom = instance["bottom"] bottom = instance["bottom"]
top = instance["top"] top = instance["top"]
confidence = instance["confidence"]
x = int(left * aspect_ratio_x) x = int(left * aspect_ratio_x)
y = int(top * aspect_ratio_y) y = int(top * aspect_ratio_y)
w = int(right * aspect_ratio_x) - int(left * aspect_ratio_x) w = int(right * aspect_ratio_x) - int(left * aspect_ratio_x)
h = int(bottom * aspect_ratio_y) - int(top * aspect_ratio_y) h = int(bottom * aspect_ratio_y) - int(top * aspect_ratio_y)
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
detected_face = img[int(y) : int(y + h), int(x) : int(x + w)] detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
face_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = instance["confidence"]
if align:
opencv_module: OpenCv.OpenCvClient = self.model["opencv_module"]
left_eye, right_eye = opencv_module.find_eyes(detected_face) left_eye, right_eye = opencv_module.find_eyes(detected_face)
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye, right_eye=right_eye
)
detected_face_obj = DetectedFace( facial_area = FacialAreaRegion(
img=detected_face, x=x,
facial_area=face_region, y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence, confidence=confidence,
) )
resp.append(facial_area)
resp.append(detected_face_obj)
return resp return resp

View File

@ -2,8 +2,7 @@ import os
from typing import Any, List from typing import Any, List
import numpy as np import numpy as np
import gdown import gdown
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.modules import detection
from deepface.commons import folder_utils from deepface.commons import folder_utils
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
@ -50,28 +49,15 @@ class YoloClient(Detector):
# Return face_detector # Return face_detector
return YOLO(weight_path) return YOLO(weight_path)
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = False, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with yolo Detect faces with yolo
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
resp = [] resp = []
@ -84,36 +70,25 @@ class YoloClient(Detector):
x, y, w, h = result.boxes.xywh.tolist()[0] x, y, w, h = result.boxes.xywh.tolist()[0]
confidence = result.boxes.conf.tolist()[0] confidence = result.boxes.conf.tolist()[0]
# left_eye_conf = result.keypoints.conf[0][0]
# right_eye_conf = result.keypoints.conf[0][1]
left_eye = result.keypoints.xy[0][0].tolist()
right_eye = result.keypoints.xy[0][1].tolist()
# eyes are lists of floats; cast them to tuples of ints
left_eye = tuple(int(i) for i in left_eye)
right_eye = tuple(int(i) for i in right_eye)
x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h) x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
region = FacialAreaRegion(x=x, y=y, w=w, h=h) facial_area = FacialAreaRegion(
x=x,
# expand the facial area to be extracted and stay within img.shape limits y=y,
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left w=w,
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top h=h,
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right left_eye=left_eye,
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom right_eye=right_eye,
confidence=confidence,
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
if align:
# Tuple of x,y and confidence for left eye
left_eye = result.keypoints.xy[0][0], result.keypoints.conf[0][0]
# Tuple of x,y and confidence for right eye
right_eye = result.keypoints.xy[0][1], result.keypoints.conf[0][1]
# Check the landmarks confidence before alignment
if (
left_eye[1] > LANDMARKS_CONFIDENCE_THRESHOLD
and right_eye[1] > LANDMARKS_CONFIDENCE_THRESHOLD
):
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye[0].cpu(), right_eye=right_eye[0].cpu()
) )
resp.append(facial_area)
detected_face_obj = DetectedFace(
img=detected_face, facial_area=region, confidence=confidence
)
resp.append(detected_face_obj)
return resp return resp

View File

@ -4,8 +4,7 @@ import cv2
import numpy as np import numpy as np
import gdown import gdown
from deepface.commons import folder_utils from deepface.commons import folder_utils
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.modules import detection
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
logger = Logger(module="detectors.YunetWrapper") logger = Logger(module="detectors.YunetWrapper")
@ -49,34 +48,20 @@ class YuNetClient(Detector):
) from err ) from err
return face_detector return face_detector
def detect_faces( def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
""" """
Detect and align face with yunet Detect faces with yunet
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
# FaceDetector.detect_faces does not support score_threshold parameter. # FaceDetector.detect_faces does not support score_threshold parameter.
# We can set it via environment variable. # We can set it via environment variable.
score_threshold = float(os.environ.get("yunet_score_threshold", "0.9")) score_threshold = float(os.environ.get("yunet_score_threshold", "0.9"))
resp = [] resp = []
detected_face = None
faces = [] faces = []
height, width = img.shape[0], img.shape[1] height, width = img.shape[0], img.shape[1]
# resize image if it is too large (Yunet fails to detect faces on large input sometimes) # resize image if it is too large (Yunet fails to detect faces on large input sometimes)
@ -108,6 +93,8 @@ class YuNetClient(Detector):
left eye, nose tip, the right corner and left corner of the mouth respectively. left eye, nose tip, the right corner and left corner of the mouth respectively.
""" """
(x, y, w, h, x_re, y_re, x_le, y_le) = list(map(int, face[:8])) (x, y, w, h, x_re, y_re, x_le, y_le) = list(map(int, face[:8]))
left_eye = (x_re, y_re)
right_eye = (x_le, y_le)
# Yunet returns negative coordinates if it thinks part of # Yunet returns negative coordinates if it thinks part of
# the detected face is outside the frame. # the detected face is outside the frame.
@ -123,24 +110,16 @@ class YuNetClient(Detector):
int(x_le / r), int(x_le / r),
int(y_le / r), int(y_le / r),
) )
confidence = face[-1] confidence = float(face[-1])
confidence = f"{confidence:.2f}"
# expand the facial area to be extracted and stay within img.shape limits facial_area = FacialAreaRegion(
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left x=x,
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top y=y,
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right w=w,
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom h=h,
confidence=confidence,
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)] left_eye=left_eye,
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)] right_eye=right_eye,
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
if align:
detected_face = detection.align_face(detected_face, (x_re, y_re), (x_le, y_le))
detected_face_obj = DetectedFace(
img=detected_face, facial_area=img_region, confidence=confidence
) )
resp.append(detected_face_obj) resp.append(facial_area)
return resp return resp

View File

@ -1,4 +1,4 @@
from typing import List from typing import List, Tuple, Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
@ -8,28 +8,19 @@ import numpy as np
# pylint: disable=unnecessary-pass, too-few-public-methods # pylint: disable=unnecessary-pass, too-few-public-methods
class Detector(ABC): class Detector(ABC):
@abstractmethod @abstractmethod
def detect_faces( def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List["DetectedFace"]:
""" """
Interface for detect and align face Interface to detect faces
Args: Args:
img (np.ndarray): pre-loaded image as numpy array img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns: Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains: where each object contains:
- img (np.ndarray): The detected face as a NumPy array. - facial_area (FacialAreaRegion): The facial area region represented
as x, y, w, h, left_eye and right_eye
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
""" """
pass pass
@ -39,12 +30,27 @@ class FacialAreaRegion:
y: int y: int
w: int w: int
h: int h: int
left_eye: Tuple[int, int]
right_eye: Tuple[int, int]
confidence: float
def __init__(self, x: int, y: int, w: int, h: int): def __init__(
self,
x: int,
y: int,
w: int,
h: int,
left_eye: Optional[Tuple[int, int]] = None,
right_eye: Optional[Tuple[int, int]] = None,
confidence: Optional[float] = None,
):
self.x = x self.x = x
self.y = y self.y = y
self.w = w self.w = w
self.h = h self.h = h
self.left_eye = left_eye
self.right_eye = right_eye
self.confidence = confidence
class DetectedFace: class DetectedFace:
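Under the new contract a backend only reports facial areas plus eye landmarks; alignment and expansion happen once in the wrapper. A minimal, hypothetical backend (class name and behavior invented for illustration) would therefore reduce to:

from typing import List
import numpy as np
from deepface.models.Detector import Detector, FacialAreaRegion

class WholeFrameClient(Detector):
    # toy backend: report the whole frame as a single face with unknown eyes
    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
        h, w = img.shape[0], img.shape[1]
        return [
            FacialAreaRegion(
                x=0, y=0, w=w, h=h, left_eye=None, right_eye=None, confidence=0.0
            )
        ]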

View File

@ -1,5 +1,5 @@
# built-in dependencies # built-in dependencies
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union, Optional
# 3rd part dependencies # 3rd part dependencies
import numpy as np import numpy as np
@ -27,7 +27,7 @@ elif tf_major_version == 2:
def extract_faces( def extract_faces(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray],
target_size: Tuple[int, int] = (224, 224), target_size: Optional[Tuple[int, int]] = (224, 224),
detector_backend: str = "opencv", detector_backend: str = "opencv",
enforce_detection: bool = True, enforce_detection: bool = True,
align: bool = True, align: bool = True,
@ -76,7 +76,7 @@ def extract_faces(
# img might be path, base64 or numpy array. Convert it to numpy whatever it is. # img might be path, base64 or numpy array. Convert it to numpy whatever it is.
img, img_name = preprocessing.load_image(img_path) img, img_name = preprocessing.load_image(img_path)
base_region = FacialAreaRegion(x=0, y=0, w=img.shape[1], h=img.shape[0]) base_region = FacialAreaRegion(x=0, y=0, w=img.shape[1], h=img.shape[0], confidence=0)
if detector_backend == "skip": if detector_backend == "skip":
face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)] face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
@ -108,7 +108,6 @@ def extract_faces(
for face_obj in face_objs: for face_obj in face_objs:
current_img = face_obj.img current_img = face_obj.img
current_region = face_obj.facial_area current_region = face_obj.facial_area
confidence = face_obj.confidence
if current_img.shape[0] == 0 or current_img.shape[1] == 0: if current_img.shape[0] == 0 or current_img.shape[1] == 0:
continue continue
@ -117,6 +116,7 @@ def extract_faces(
current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY) current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
# resize and padding # resize and padding
if target_size is not None:
factor_0 = target_size[0] / current_img.shape[0] factor_0 = target_size[0] / current_img.shape[0]
factor_1 = target_size[1] / current_img.shape[1] factor_1 = target_size[1] / current_img.shape[1]
factor = min(factor_0, factor_1) factor = min(factor_0, factor_1)
@ -171,14 +171,17 @@ def extract_faces(
"y": int(current_region.y), "y": int(current_region.y),
"w": int(current_region.w), "w": int(current_region.w),
"h": int(current_region.h), "h": int(current_region.h),
"left_eye": current_region.left_eye,
"right_eye": current_region.right_eye,
}, },
"confidence": confidence, "confidence": round(current_region.confidence, 2),
} }
) )
if len(resp_objs) == 0 and enforce_detection == True: if len(resp_objs) == 0 and enforce_detection == True:
raise ValueError( raise ValueError(
f"Detected face shape is {img.shape}. Consider to set enforce_detection arg to False." f"Exception while extracting faces from {img_name}."
"Consider to set enforce_detection arg to False."
) )
return resp_objs return resp_objs
@ -188,7 +191,7 @@ def align_face(
img: np.ndarray, img: np.ndarray,
left_eye: Union[list, tuple], left_eye: Union[list, tuple],
right_eye: Union[list, tuple], right_eye: Union[list, tuple],
) -> np.ndarray: ) -> Tuple[np.ndarray, float]:
""" """
Align a given image horizontally with respect to its left and right eye locations
Args: Args:
@ -200,13 +203,13 @@ def align_face(
""" """
# if eyes could not be detected for the given image, return the image itself # if eyes could not be detected for the given image, return the image itself
if left_eye is None or right_eye is None: if left_eye is None or right_eye is None:
return img return img, 0
# sometimes detected faces unexpectedly come with zero dimensions # sometimes detected faces unexpectedly come with zero dimensions
if img.shape[0] == 0 or img.shape[1] == 0: if img.shape[0] == 0 or img.shape[1] == 0:
return img return img, 0
angle = float(np.degrees(np.arctan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0]))) angle = float(np.degrees(np.arctan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0])))
img = Image.fromarray(img) img = Image.fromarray(img)
img = np.array(img.rotate(angle)) img = np.array(img.rotate(angle))
return img return img, angle
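End to end, the reworked pipeline surfaces the eye landmarks in the response and, with target_size now Optional, can return crops at native resolution. A hypothetical call against the public API follows (the backend choice is arbitrary; the image path is the one used in the tests below):

from deepface import DeepFace

face_objs = DeepFace.extract_faces(
    img_path="dataset/img11.jpg",
    detector_backend="retinaface",
    align=True,
    target_size=None,  # skip the resize-and-padding step
)
for face_obj in face_objs:
    area = face_obj["facial_area"]
    print(area["x"], area["y"], area["w"], area["h"])
    print(area["left_eye"], area["right_eye"], face_obj["confidence"])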

View File

@ -354,6 +354,7 @@ def __find_bulk_embeddings(
desc="Finding representations", desc="Finding representations",
disable=silent, disable=silent,
): ):
try:
img_objs = detection.extract_faces( img_objs = detection.extract_faces(
img_path=employee, img_path=employee,
target_size=target_size, target_size=target_size,
@ -363,6 +364,11 @@ def __find_bulk_embeddings(
align=align, align=align,
expand_percentage=expand_percentage, expand_percentage=expand_percentage,
) )
except ValueError as err:
logger.warn(
f"Exception while extracting faces from {employee}: {str(err)}. Skipping it."
)
img_objs = []
for img_obj in img_objs: for img_obj in img_objs:
img_content = img_obj["face"] img_content = img_obj["face"]

Binary file not shown (image; after: 232 KiB).

View File

@ -53,14 +53,37 @@ dfs = DeepFace.find(
for df in dfs: for df in dfs:
logger.info(df) logger.info(df)
# img_paths = ["dataset/img11.jpg", "dataset/img11_reflection.jpg", "dataset/couple.jpg"]
img_paths = ["dataset/img11.jpg"]
for img_path in img_paths:
# extract faces # extract faces
for detector_backend in detector_backends: for detector_backend in detector_backends:
face_objs = DeepFace.extract_faces( face_objs = DeepFace.extract_faces(
img_path="dataset/img11.jpg", detector_backend=detector_backend img_path=img_path,
detector_backend=detector_backend,
align=True,
# expand_percentage=10,
# target_size=None,
) )
for face_obj in face_objs: for face_obj in face_objs:
face = face_obj["face"] face = face_obj["face"]
logger.info(detector_backend) logger.info(detector_backend)
logger.info(face_obj["facial_area"])
logger.info(face_obj["confidence"])
# we know opencv sometimes cannot find eyes
if face_obj["facial_area"]["left_eye"] is not None:
assert isinstance(face_obj["facial_area"]["left_eye"], tuple)
assert isinstance(face_obj["facial_area"]["left_eye"][0], int)
assert isinstance(face_obj["facial_area"]["left_eye"][1], int)
if face_obj["facial_area"]["right_eye"] is not None:
assert isinstance(face_obj["facial_area"]["right_eye"], tuple)
assert isinstance(face_obj["facial_area"]["right_eye"][0], int)
assert isinstance(face_obj["facial_area"]["right_eye"][1], int)
assert isinstance(face_obj["confidence"], float)
plt.imshow(face) plt.imshow(face)
plt.axis("off") plt.axis("off")
plt.show() plt.show()