pseudo-batching centerface

This commit is contained in:
galthran-wq 2025-02-18 09:47:06 +00:00
parent 1d358aa15a
commit f5188c802c

View File

@ -1,6 +1,6 @@
# built-in dependencies
import os
from typing import List
from typing import List, Union
# 3rd party dependencies
import numpy as np
@ -34,12 +34,29 @@ class CenterFaceClient(Detector):
return CenterFace(weight_path=weights_path)
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
def detect_faces(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]:
"""
Detect and align face with CenterFace
Args:
img (np.ndarray): pre-loaded image as numpy array
img (Union[np.ndarray, List[np.ndarray]]): pre-loaded image as numpy array or a list of those
Returns:
results (Union[List[FacialAreaRegion], List[List[FacialAreaRegion]]]): A list or a list of lists of FacialAreaRegion objects
"""
if isinstance(img, np.ndarray):
return self._process_single_image(img)
elif isinstance(img, list):
return [self._process_single_image(single_img) for single_img in img]
else:
raise ValueError("Input must be a numpy array or a list of numpy arrays.")
def _process_single_image(self, single_img: np.ndarray) -> List[FacialAreaRegion]:
"""
Helper function to detect faces in a single image.
Args:
single_img (np.ndarray): pre-loaded image as numpy array
Returns:
results (List[FacialAreaRegion]): A list of FacialAreaRegion objects
@ -53,7 +70,7 @@ class CenterFaceClient(Detector):
# img, img.shape[0], img.shape[1], threshold=threshold
# )
detections, landmarks = self.build_model().forward(
img, img.shape[0], img.shape[1], threshold=threshold
single_img, single_img.shape[0], single_img.shape[1], threshold=threshold
)
for i, detection in enumerate(detections):