Merge pull request #1025 from serengil/feat-task-1602-new-detector-interface

Feat task 1602 new detector interface
This commit is contained in:
Sefik Ilkin Serengil 2024-02-16 17:50:50 +00:00 committed by GitHub
commit ee4ad41c2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 357 additions and 462 deletions

View File

@ -423,7 +423,7 @@ def stream(
def extract_faces(
img_path: Union[str, np.ndarray],
target_size: Tuple[int, int] = (224, 224),
target_size: Optional[Tuple[int, int]] = (224, 224),
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
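With target_size now Optional, the resize-and-pad step can be skipped entirely. A minimal usage sketch (deepface assumed installed; the image path is borrowed from the tests at the end of this diff):

from deepface import DeepFace

# passing target_size=None should return the face crop at its native
# resolution instead of resizing/padding it to (224, 224)
face_objs = DeepFace.extract_faces(img_path="dataset/img11.jpg", target_size=None)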

View File

@ -1,6 +1,7 @@
from typing import Any, List
from typing import Any, List, Tuple
import numpy as np
from deepface.models.Detector import Detector, DetectedFace
from deepface.modules import detection
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.detectors import (
FastMtCnn,
MediaPipe,
@ -80,10 +81,101 @@ def detect_faces(
- confidence (float): The confidence score associated with the detected face.
"""
face_detector: Detector = build_model(detector_backend)
# validate expand percentage
if expand_percentage < 0:
logger.warn(
f"Expand percentage cannot be negative but you set it to {expand_percentage}."
"Overwritten it to 0."
)
expand_percentage = 0
return face_detector.detect_faces(img=img, align=align, expand_percentage=expand_percentage)
# find facial areas of given image
facial_areas = face_detector.detect_faces(img=img)
results = []
for facial_area in facial_areas:
x = facial_area.x
y = facial_area.y
w = facial_area.w
h = facial_area.h
left_eye = facial_area.left_eye
right_eye = facial_area.right_eye
confidence = facial_area.confidence
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * 2 * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * 2 * expand_percentage) / 100)) # expand bottom
# extract detected face unaligned
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
# aligning detected face causes a lot of black pixels
# if align is True:
# detected_face, _ = detection.align_face(
# img=detected_face, left_eye=left_eye, right_eye=right_eye
# )
# align original image, then find projection of detected face area after alignment
if align is True: # and left_eye is not None and right_eye is not None:
aligned_img, angle = detection.align_face(
img=img, left_eye=left_eye, right_eye=right_eye
)
x1_new, y1_new, x2_new, y2_new = rotate_facial_area(
facial_area=(x2, y2, x2 + w2, y2 + h2), angle=angle, direction=1, size=img.shape
)
detected_face = aligned_img[int(y1_new) : int(y2_new), int(x1_new) : int(x2_new)]
result = DetectedFace(
img=detected_face,
facial_area=FacialAreaRegion(
x=x, y=y, h=h, w=w, confidence=confidence, left_eye=left_eye, right_eye=right_eye
),
confidence=confidence,
)
results.append(result)
return results
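A quick worked example of the expansion arithmetic above, with made-up numbers:

# a 100x100 face at (x=50, y=50) in a 640x480 image, expand_percentage=10
x, y, w, h = 50, 50, 100, 100
expand_percentage = 10
x2 = max(0, x - int((w * expand_percentage) / 100))  # 50 - 10 = 40
y2 = max(0, y - int((h * expand_percentage) / 100))  # 50 - 10 = 40
w2 = min(640, w + int((w * 2 * expand_percentage) / 100))  # 100 + 20 = 120
h2 = min(480, h + int((h * 2 * expand_percentage) / 100))  # 100 + 20 = 120
# so the crop img[40:160, 40:160] extends the box by 10% on every side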
def rotate_facial_area(
facial_area: Tuple[int, int, int, int], angle: float, direction: int, size: Tuple[int, int]
) -> Tuple[int, int, int, int]:
"""
Rotate the facial area around its center.
Inspired by the work of @UmutDeniz26 - github.com/serengil/retinaface/pull/80
Args:
facial_area (tuple of int): Representing the (x1, y1, x2, y2) of the facial area.
x2 is equal to x1 + w1, and y2 is equal to y1 + h1
angle (float): Angle of rotation in degrees.
direction (int): Direction of rotation (-1 for clockwise, 1 for counterclockwise).
size (tuple of int): Tuple representing the size of the image (height, width), i.e. img.shape.
Returns:
rotated_coordinates (tuple of int): Representing the new coordinates
(x1, y1, x2, y2) or (x1, y1, x1+w1, y1+h1) of the rotated facial area.
"""
# Angle in radians
angle = angle * np.pi / 180
# Translate the facial area to the center of the image
x = (facial_area[0] + facial_area[2]) / 2 - size[1] / 2
y = (facial_area[1] + facial_area[3]) / 2 - size[0] / 2
# Rotate the facial area
x_new = x * np.cos(angle) + y * direction * np.sin(angle)
y_new = -x * direction * np.sin(angle) + y * np.cos(angle)
# Translate the facial area back to the original position
x_new = x_new + size[1] / 2
y_new = y_new + size[0] / 2
# Calculate the new facial area
x1 = x_new - (facial_area[2] - facial_area[0]) / 2
y1 = y_new - (facial_area[3] - facial_area[1]) / 2
x2 = x_new + (facial_area[2] - facial_area[0]) / 2
y2 = y_new + (facial_area[3] - facial_area[1]) / 2
return (int(x1), int(y1), int(x2), int(y2))
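Because the rotation is about the image centre, a facial area centred in the image is a fixed point of the mapping; a quick property check of rotate_facial_area with made-up numbers:

# a box centred in a 100x100 image maps onto itself for any angle
area = (40, 40, 60, 60)
for angle in (0.0, 45.0, 90.0):
    assert rotate_facial_area(facial_area=area, angle=angle, direction=1, size=(100, 100)) == area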

View File

@ -4,7 +4,7 @@ import bz2
import gdown
import numpy as np
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons.logger import Logger
logger = Logger(module="detectors.DlibWrapper")
@ -56,50 +56,18 @@ class DlibClient(Detector):
detector["sp"] = sp
return detector
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with dlib
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
# this is an optional dependency; do not import it at the global level
try:
import dlib
except ModuleNotFoundError as e:
raise ImportError(
"Dlib is an optional detector, ensure the library is installed."
"Please install using 'pip install dlib' "
) from e
if expand_percentage != 0:
logger.warn(
f"You set expand_percentage argument to {expand_percentage},"
"but dlib hog handles detection by itself"
)
resp = []
sp = self.model["sp"]
detected_face = None
face_detector = self.model["face_detector"]
# note that, by design, dlib's fhog face detector scores are >0 but not capped at 1
@ -107,30 +75,32 @@ class DlibClient(Detector):
if len(detections) > 0:
for idx, d in enumerate(detections):
left = d.left()
right = d.right()
top = d.top()
bottom = d.bottom()
for idx, detection in enumerate(detections):
left = detection.left()
right = detection.right()
top = detection.top()
bottom = detection.bottom()
y = int(max(0, top))
h = int(min(bottom, img.shape[0]) - y)
x = int(max(0, left))
w = int(min(right, img.shape[1]) - x)
detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
shape = self.model["sp"](img, detection)
left_eye = (shape.part(2).x, shape.part(2).y)
right_eye = (shape.part(0).x, shape.part(0).y)
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = scores[idx]
if align:
img_shape = sp(img, detections[idx])
detected_face = dlib.get_face_chip(img, img_shape, size=detected_face.shape[0])
detected_face_obj = DetectedFace(
img=detected_face, facial_area=img_region, confidence=confidence
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp

View File

@ -1,8 +1,7 @@
from typing import Any, Union, List
import cv2
import numpy as np
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.modules import detection
from deepface.models.Detector import Detector, FacialAreaRegion
# Link -> https://github.com/timesler/facenet-pytorch
# Examples https://www.kaggle.com/timesler/guide-to-mtcnn-in-facenet-pytorch
@ -12,33 +11,18 @@ class FastMtCnnClient(Detector):
def __init__(self):
self.model = self.build_model()
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with mtcnn
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
detected_face = None
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # mtcnn expects RGB but OpenCV reads BGR
detections = self.model.detect(
img_rgb, landmarks=True
@ -47,31 +31,20 @@ class FastMtCnnClient(Detector):
for current_detection in zip(*detections):
x, y, w, h = xyxy_to_xywh(current_detection[0])
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = current_detection[1]
if align:
left_eye = current_detection[2][0]
right_eye = current_detection[2][1]
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye, right_eye=right_eye
)
detected_face_obj = DetectedFace(
img=detected_face, facial_area=img_region, confidence=confidence
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp
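The xyxy_to_xywh helper called above is defined elsewhere in this file and is not shown in the diff; presumably it converts corner coordinates to origin-plus-size form, roughly:

def xyxy_to_xywh(xyxy):
    # assumed behaviour: (x1, y1, x2, y2) -> (x, y, w, h)
    x, y = xyxy[0], xyxy[1]
    return x, y, xyxy[2] - x, xyxy[3] - y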

View File

@ -1,7 +1,6 @@
from typing import Any, List
import numpy as np
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.modules import detection
from deepface.models.Detector import Detector, FacialAreaRegion
# Link - https://google.github.io/mediapipe/solutions/face_detection
@ -29,28 +28,15 @@ class MediaPipeClient(Detector):
face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.7)
return face_detection
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with mediapipe
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
@ -75,7 +61,6 @@ class MediaPipeClient(Detector):
y = int(bounding_box.ymin * img_height)
h = int(bounding_box.height * img_height)
# Extract landmarks
left_eye = (int(landmarks[0].x * img_width), int(landmarks[0].y * img_height))
right_eye = (int(landmarks[1].x * img_width), int(landmarks[1].y * img_height))
# nose = (int(landmarks[2].x * img_width), int(landmarks[2].y * img_height))
@ -83,30 +68,9 @@ class MediaPipeClient(Detector):
# right_ear = (int(landmarks[4].x * img_width), int(landmarks[4].y * img_height))
# left_ear = (int(landmarks[5].x * img_width), int(landmarks[5].y * img_height))
if x > 0 and y > 0:
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
if align:
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye, right_eye=right_eye
facial_area = FacialAreaRegion(
x=x, y=y, w=w, h=h, left_eye=left_eye, right_eye=right_eye, confidence=confidence
)
detected_face_obj = DetectedFace(
img=detected_face,
facial_area=img_region,
confidence=confidence,
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp

View File

@ -1,8 +1,7 @@
from typing import List
import numpy as np
from mtcnn import MTCNN
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.modules import detection
from deepface.models.Detector import Detector, FacialAreaRegion
# pylint: disable=too-few-public-methods
class MtCnnClient(Detector):
@ -13,34 +12,19 @@ class MtCnnClient(Detector):
def __init__(self):
self.model = MTCNN()
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with mtcnn
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
detected_face = None
# mtcnn expects RGB but OpenCV reads BGR
# img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_rgb = img[:, :, ::-1]
@ -50,31 +34,20 @@ class MtCnnClient(Detector):
for current_detection in detections:
x, y, w, h = current_detection["box"]
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = current_detection["confidence"]
left_eye = current_detection["keypoints"]["left_eye"]
right_eye = current_detection["keypoints"]["right_eye"]
if align:
keypoints = current_detection["keypoints"]
left_eye = keypoints["left_eye"]
right_eye = keypoints["right_eye"]
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye, right_eye=right_eye
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
)
detected_face_obj = DetectedFace(
img=detected_face, facial_area=img_region, confidence=confidence
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp

View File

@ -2,8 +2,7 @@ import os
from typing import Any, List
import cv2
import numpy as np
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.modules import detection
from deepface.models.Detector import Detector, FacialAreaRegion
class OpenCvClient(Detector):
@ -25,28 +24,15 @@ class OpenCvClient(Detector):
detector["eye_detector"] = self.__build_cascade("haarcascade_eye")
return detector
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with opencv
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
@ -65,27 +51,18 @@ class OpenCvClient(Detector):
if len(faces) > 0:
for (x, y, w, h), confidence in zip(faces, scores):
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
if align:
detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
left_eye, right_eye = self.find_eyes(img=detected_face)
detected_face = detection.align_face(detected_face, left_eye, right_eye)
detected_face_obj = DetectedFace(
img=detected_face,
facial_area=FacialAreaRegion(x, y, w, h),
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp

View File

@ -1,36 +1,22 @@
from typing import List
import numpy as np
from retinaface import RetinaFace as rf
from retinaface.commons import postprocess
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.models.Detector import Detector, FacialAreaRegion
# pylint: disable=too-few-public-methods
class RetinaFaceClient(Detector):
def __init__(self):
self.model = rf.build_model()
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with retinaface
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
@ -41,42 +27,33 @@ class RetinaFaceClient(Detector):
for face_idx in obj.keys():
identity = obj[face_idx]
facial_area = identity["facial_area"]
detection = identity["facial_area"]
y = detection[1]
h = detection[3] - y
x = detection[0]
w = detection[2] - x
# note that retinaface's left and right eye labels must be swapped here
left_eye = identity["landmarks"]["right_eye"]
right_eye = identity["landmarks"]["left_eye"]
# eyes are lists of floats; cast them to tuples of ints
left_eye = tuple(int(i) for i in left_eye)
right_eye = tuple(int(i) for i in right_eye)
y = facial_area[1]
h = facial_area[3] - y
x = facial_area[0]
w = facial_area[2] - x
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = identity["score"]
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
if align:
landmarks = identity["landmarks"]
left_eye = landmarks["left_eye"]
right_eye = landmarks["right_eye"]
nose = landmarks["nose"]
# mouth_right = landmarks["mouth_right"]
# mouth_left = landmarks["mouth_left"]
detected_face = postprocess.alignment_procedure(
detected_face, right_eye, left_eye, nose
)
detected_face_obj = DetectedFace(
img=detected_face,
facial_area=img_region,
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp

View File

@ -6,8 +6,7 @@ import pandas as pd
import numpy as np
from deepface.detectors import OpenCv
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.modules import detection
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons.logger import Logger
logger = Logger(module="detectors.SsdWrapper")
@ -71,29 +70,18 @@ class SsdClient(Detector):
return detector
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with ssd
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
opencv_module: OpenCv.OpenCvClient = self.model["opencv_module"]
resp = []
detected_face = None
@ -133,37 +121,26 @@ class SsdClient(Detector):
right = instance["right"]
bottom = instance["bottom"]
top = instance["top"]
confidence = instance["confidence"]
x = int(left * aspect_ratio_x)
y = int(top * aspect_ratio_y)
w = int(right * aspect_ratio_x) - int(left * aspect_ratio_x)
h = int(bottom * aspect_ratio_y) - int(top * aspect_ratio_y)
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
face_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
confidence = instance["confidence"]
if align:
opencv_module: OpenCv.OpenCvClient = self.model["opencv_module"]
left_eye, right_eye = opencv_module.find_eyes(detected_face)
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye, right_eye=right_eye
)
detected_face_obj = DetectedFace(
img=detected_face,
facial_area=face_region,
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
)
resp.append(facial_area)
resp.append(detected_face_obj)
return resp

View File

@ -2,8 +2,7 @@ import os
from typing import Any, List
import numpy as np
import gdown
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.modules import detection
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons import folder_utils
from deepface.commons.logger import Logger
@ -50,28 +49,15 @@ class YoloClient(Detector):
# Return face_detector
return YOLO(weight_path)
def detect_faces(
self, img: np.ndarray, align: bool = False, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with yolo
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
resp = []
@ -84,36 +70,25 @@ class YoloClient(Detector):
x, y, w, h = result.boxes.xywh.tolist()[0]
confidence = result.boxes.conf.tolist()[0]
# left_eye_conf = result.keypoints.conf[0][0]
# right_eye_conf = result.keypoints.conf[0][1]
left_eye = result.keypoints.xy[0][0].tolist()
right_eye = result.keypoints.xy[0][1].tolist()
# eyes are lists of floats; cast them to tuples of ints
left_eye = tuple(int(i) for i in left_eye)
right_eye = tuple(int(i) for i in right_eye)
x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h)
region = FacialAreaRegion(x=x, y=y, w=w, h=h)
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
if align:
# Tuple of x,y and confidence for left eye
left_eye = result.keypoints.xy[0][0], result.keypoints.conf[0][0]
# Tuple of x,y and confidence for right eye
right_eye = result.keypoints.xy[0][1], result.keypoints.conf[0][1]
# Check the landmarks confidence before alignment
if (
left_eye[1] > LANDMARKS_CONFIDENCE_THRESHOLD
and right_eye[1] > LANDMARKS_CONFIDENCE_THRESHOLD
):
detected_face = detection.align_face(
img=detected_face, left_eye=left_eye[0].cpu(), right_eye=right_eye[0].cpu()
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
left_eye=left_eye,
right_eye=right_eye,
confidence=confidence,
)
detected_face_obj = DetectedFace(
img=detected_face, facial_area=region, confidence=confidence
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp

View File

@ -4,8 +4,7 @@ import cv2
import numpy as np
import gdown
from deepface.commons import folder_utils
from deepface.models.Detector import Detector, DetectedFace, FacialAreaRegion
from deepface.modules import detection
from deepface.models.Detector import Detector, FacialAreaRegion
from deepface.commons.logger import Logger
logger = Logger(module="detectors.YunetWrapper")
@ -49,34 +48,20 @@ class YuNetClient(Detector):
) from err
return face_detector
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List[DetectedFace]:
def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
"""
Detect and align face with yunet
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
"""
# FaceDetector.detect_faces does not support score_threshold parameter.
# We can set it via environment variable.
score_threshold = float(os.environ.get("yunet_score_threshold", "0.9"))
resp = []
detected_face = None
faces = []
height, width = img.shape[0], img.shape[1]
# resize image if it is too large (Yunet sometimes fails to detect faces on large inputs)
@ -108,6 +93,8 @@ class YuNetClient(Detector):
left eye, nose tip, the right corner and left corner of the mouth respectively.
"""
(x, y, w, h, x_re, y_re, x_le, y_le) = list(map(int, face[:8]))
left_eye = (x_re, y_re)
right_eye = (x_le, y_le)
# Yunet returns negative coordinates if it thinks part of
# the detected face is outside the frame.
@ -123,24 +110,16 @@ class YuNetClient(Detector):
int(x_le / r),
int(y_le / r),
)
confidence = face[-1]
confidence = f"{confidence:.2f}"
confidence = float(face[-1])
# expand the facial area to be extracted and stay within img.shape limits
x2 = max(0, x - int((w * expand_percentage) / 100)) # expand left
y2 = max(0, y - int((h * expand_percentage) / 100)) # expand top
w2 = min(img.shape[1], w + int((w * expand_percentage) / 100)) # expand right
h2 = min(img.shape[0], h + int((h * expand_percentage) / 100)) # expand bottom
# detected_face = img[int(y) : int(y + h), int(x) : int(x + w)]
detected_face = img[int(y2) : int(y2 + h2), int(x2) : int(x2 + w2)]
img_region = FacialAreaRegion(x=x, y=y, w=w, h=h)
if align:
detected_face = detection.align_face(detected_face, (x_re, y_re), (x_le, y_le))
detected_face_obj = DetectedFace(
img=detected_face, facial_area=img_region, confidence=confidence
facial_area = FacialAreaRegion(
x=x,
y=y,
w=w,
h=h,
confidence=confidence,
left_eye=left_eye,
right_eye=right_eye,
)
resp.append(detected_face_obj)
resp.append(facial_area)
return resp
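As the comment above notes, YuNet's score threshold is configured through an environment variable rather than a parameter; a minimal sketch:

import os

# loosen the threshold before building the detector; it defaults to 0.9
os.environ["yunet_score_threshold"] = "0.7"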

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple, Optional
from abc import ABC, abstractmethod
import numpy as np
@ -8,28 +8,19 @@ import numpy as np
# pylint: disable=unnecessary-pass, too-few-public-methods
class Detector(ABC):
@abstractmethod
def detect_faces(
self, img: np.ndarray, align: bool = True, expand_percentage: int = 0
) -> List["DetectedFace"]:
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
"""
Interface to detect and align faces
Args:
img (np.ndarray): pre-loaded image as numpy array
align (bool): flag to enable or disable alignment after detection (default is True)
expand_percentage (int): expand detected facial area with a percentage
Returns:
results (List[Tuple[DetectedFace]): A list of DetectedFace objects
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
where each object contains:
- img (np.ndarray): The detected face as a NumPy array.
- facial_area (FacialAreaRegion): The facial area region represented as x, y, w, h
- confidence (float): The confidence score associated with the detected face.
- facial_area (FacialAreaRegion): The facial area region represented
as x, y, w, h, left_eye and right_eye
"""
pass
@ -39,12 +30,27 @@ class FacialAreaRegion:
y: int
w: int
h: int
left_eye: Tuple[int, int]
right_eye: Tuple[int, int]
confidence: float
def __init__(self, x: int, y: int, w: int, h: int):
def __init__(
self,
x: int,
y: int,
w: int,
h: int,
left_eye: Optional[Tuple[int, int]] = None,
right_eye: Optional[Tuple[int, int]] = None,
confidence: Optional[float] = None,
):
self.x = x
self.y = y
self.w = w
self.h = h
self.left_eye = left_eye
self.right_eye = right_eye
self.confidence = confidence
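Under the new contract a backend only reports facial areas; cropping, expansion and alignment happen centrally in the detection module. A minimal sketch of a conforming detector (the class itself is hypothetical):

from typing import List
import numpy as np

class WholeFrameClient(Detector):
    def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]:
        # report the whole frame as a single face with unknown eye locations
        return [
            FacialAreaRegion(
                x=0, y=0, w=img.shape[1], h=img.shape[0],
                left_eye=None, right_eye=None, confidence=0.0,
            )
        ]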
class DetectedFace:

View File

@ -1,5 +1,5 @@
# built-in dependencies
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union, Optional
# 3rd part dependencies
import numpy as np
@ -27,7 +27,7 @@ elif tf_major_version == 2:
def extract_faces(
img_path: Union[str, np.ndarray],
target_size: Tuple[int, int] = (224, 224),
target_size: Optional[Tuple[int, int]] = (224, 224),
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
@ -76,7 +76,7 @@ def extract_faces(
# img might be path, base64 or numpy array. Convert it to numpy whatever it is.
img, img_name = preprocessing.load_image(img_path)
base_region = FacialAreaRegion(x=0, y=0, w=img.shape[1], h=img.shape[0])
base_region = FacialAreaRegion(x=0, y=0, w=img.shape[1], h=img.shape[0], confidence=0)
if detector_backend == "skip":
face_objs = [DetectedFace(img=img, facial_area=base_region, confidence=0)]
@ -108,7 +108,6 @@ def extract_faces(
for face_obj in face_objs:
current_img = face_obj.img
current_region = face_obj.facial_area
confidence = face_obj.confidence
if current_img.shape[0] == 0 or current_img.shape[1] == 0:
continue
@ -117,6 +116,7 @@ def extract_faces(
current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
# resize and padding
if target_size is not None:
factor_0 = target_size[0] / current_img.shape[0]
factor_1 = target_size[1] / current_img.shape[1]
factor = min(factor_0, factor_1)
@ -171,14 +171,17 @@ def extract_faces(
"y": int(current_region.y),
"w": int(current_region.w),
"h": int(current_region.h),
"left_eye": current_region.left_eye,
"right_eye": current_region.right_eye,
},
"confidence": confidence,
"confidence": round(current_region.confidence, 2),
}
)
if len(resp_objs) == 0 and enforce_detection == True:
raise ValueError(
f"Detected face shape is {img.shape}. Consider to set enforce_detection arg to False."
f"Exception while extracting faces from {img_name}."
"Consider to set enforce_detection arg to False."
)
return resp_objs
@ -188,7 +191,7 @@ def align_face(
img: np.ndarray,
left_eye: Union[list, tuple],
right_eye: Union[list, tuple],
) -> np.ndarray:
) -> Tuple[np.ndarray, float]:
"""
Align a given image horizontally with respect to the left and right eye locations
Args:
@ -200,13 +203,13 @@ def align_face(
"""
# if eyes could not be detected for the given image, return the image itself
if left_eye is None or right_eye is None:
return img
return img, 0
# sometimes detected faces unexpectedly come with zero dimensions
if img.shape[0] == 0 or img.shape[1] == 0:
return img
return img, 0
angle = float(np.degrees(np.arctan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0])))
img = Image.fromarray(img)
img = np.array(img.rotate(angle))
return img
return img, angle
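Callers now unpack a tuple and can feed the angle into rotate_facial_area to project the face box onto the rotated image; a usage sketch with made-up eye coordinates, assuming a pre-loaded image img:

aligned_img, angle = align_face(img=img, left_eye=(120, 140), right_eye=(180, 138))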

View File

@ -354,6 +354,7 @@ def __find_bulk_embeddings(
desc="Finding representations",
disable=silent,
):
try:
img_objs = detection.extract_faces(
img_path=employee,
target_size=target_size,
@ -363,6 +364,11 @@ def __find_bulk_embeddings(
align=align,
expand_percentage=expand_percentage,
)
except ValueError as err:
logger.warn(
f"Exception while extracting faces from {employee}: {str(err)}. Skipping it."
)
img_objs = []
for img_obj in img_objs:
img_content = img_obj["face"]

Binary file not shown (new image, 232 KiB)

View File

@ -53,14 +53,37 @@ dfs = DeepFace.find(
for df in dfs:
logger.info(df)
# img_paths = ["dataset/img11.jpg", "dataset/img11_reflection.jpg", "dataset/couple.jpg"]
img_paths = ["dataset/img11.jpg"]
for img_path in img_paths:
# extract faces
for detector_backend in detector_backends:
face_objs = DeepFace.extract_faces(
img_path="dataset/img11.jpg", detector_backend=detector_backend
img_path=img_path,
detector_backend=detector_backend,
align=True,
# expand_percentage=10,
# target_size=None,
)
for face_obj in face_objs:
face = face_obj["face"]
logger.info(detector_backend)
logger.info(face_obj["facial_area"])
logger.info(face_obj["confidence"])
# we know opencv sometimes cannot find eyes
if face_obj["facial_area"]["left_eye"] is not None:
assert isinstance(face_obj["facial_area"]["left_eye"], tuple)
assert isinstance(face_obj["facial_area"]["left_eye"][0], int)
assert isinstance(face_obj["facial_area"]["left_eye"][1], int)
if face_obj["facial_area"]["right_eye"] is not None:
assert isinstance(face_obj["facial_area"]["right_eye"], tuple)
assert isinstance(face_obj["facial_area"]["right_eye"][0], int)
assert isinstance(face_obj["facial_area"]["right_eye"][1], int)
assert isinstance(face_obj["confidence"], float)
plt.imshow(face)
plt.axis("off")
plt.show()