mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 03:55:21 +00:00
pseudo-batching centerface
This commit is contained in:
parent
1d358aa15a
commit
f5188c802c
@ -1,6 +1,6 @@
|
|||||||
# built-in dependencies
|
# built-in dependencies
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
# 3rd party dependencies
|
# 3rd party dependencies
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -34,12 +34,29 @@ class CenterFaceClient(Detector):
|
|||||||
|
|
||||||
return CenterFace(weight_path=weights_path)
|
return CenterFace(weight_path=weights_path)
|
||||||
|
|
||||||
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
|
def detect_faces(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
|
||||||
"""
|
"""
|
||||||
Detect and align face with CenterFace
|
Detect and align face with CenterFace
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img (np.ndarray): pre-loaded image as numpy array
|
img (Union[np.ndarray, List[np.ndarray]]): pre-loaded image as numpy array or a list of those
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]): A list or a list of lists of FacialAreaRegion objects
|
||||||
|
"""
|
||||||
|
if isinstance(img, np.ndarray):
|
||||||
|
return self._process_single_image(img)
|
||||||
|
elif isinstance(img, list):
|
||||||
|
return [self._process_single_image(single_img) for single_img in img]
|
||||||
|
else:
|
||||||
|
raise ValueError("Input must be a numpy array or a list of numpy arrays.")
|
||||||
|
|
||||||
|
def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
|
||||||
|
"""
|
||||||
|
Helper function to detect faces in a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
single_img (np.ndarray): pre-loaded image as numpy array
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
|
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
|
||||||
@ -53,7 +70,7 @@ class CenterFaceClient(Detector):
|
|||||||
# img, img.shape[0], img.shape[1], threshold=threshold
|
# img, img.shape[0], img.shape[1], threshold=threshold
|
||||||
# )
|
# )
|
||||||
detections, landmarks = self.build_model().forward(
|
detections, landmarks = self.build_model().forward(
|
||||||
img, img.shape[0], img.shape[1], threshold=threshold
|
single_img, single_img.shape[0], single_img.shape[1], threshold=threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, detection in enumerate(detections):
|
for i, detection in enumerate(detections):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user