Add support for batched input.

This commit is contained in:
h-alice 2025-01-22 16:54:51 +08:00
parent 95bb92c933
commit 6df7b7d8e9
No known key found for this signature in database
GPG Key ID: 5708F34144A70909

View File

@ -174,7 +174,7 @@ def analyze(
expand_percentage: int = 0,
silent: bool = False,
anti_spoofing: bool = False,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Analyze facial attributes such as age, gender, emotion, and race in the provided image.
Args:
@ -206,7 +206,10 @@ def analyze(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
(List[List[Dict[str, Any]]]): A list of analysis results if received batched image,
explained below.
(List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
the analysis results for a detected face. Each dictionary in the list contains the
following keys:
@ -253,6 +256,29 @@ def analyze(
- 'middle eastern': Confidence score for Middle Eastern ethnicity.
- 'white': Confidence score for White ethnicity.
"""
if isinstance(img_path, np.ndarray) and len(img_path.shape) == 4:
# Received 4-D array, which means image batch.
# Check batch dimension and process each image separately.
if img_path.shape[0] > 1:
batch_resp_obj = []
# Execute analysis for each image in the batch.
for single_img in img_path:
resp_obj = demography.analyze(
img_path=single_img,
actions=actions,
enforce_detection=enforce_detection,
detector_backend=detector_backend,
align=align,
expand_percentage=expand_percentage,
silent=silent,
anti_spoofing=anti_spoofing,
)
# Append the response object to the batch response list.
batch_resp_obj.append(resp_obj)
return batch_resp_obj
return demography.analyze(
img_path=img_path,
actions=actions,