patch: Lint

This commit is contained in:
h-alice 2025-01-16 17:32:45 +08:00
parent 6a7bbdb926
commit 6a8d1d95d3
No known key found for this signature in database
GPG Key ID: 5708F34144A70909

View File

@ -49,7 +49,8 @@ class ApparentAgeClient(Demography):
List of images as List[np.ndarray] or List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3) Batch of images as np.ndarray (n, 224, 224, 3)
Returns: Returns:
np.ndarray (age_classes,) if single image, np.ndarray (n, age_classes) if batched images. np.ndarray (age_classes,) if single image,
np.ndarray (n, age_classes) if batched images.
""" """
# Preprocessing input image or image list. # Preprocessing input image or image list.
imgs = self._preprocess_batch_or_single_input(img) imgs = self._preprocess_batch_or_single_input(img)
@ -60,8 +61,9 @@ class ApparentAgeClient(Demography):
# Calculate apparent ages # Calculate apparent ages
if len(age_predictions.shape) == 1: # Single prediction list if len(age_predictions.shape) == 1: # Single prediction list
return find_apparent_age(age_predictions) return find_apparent_age(age_predictions)
else: # Batched predictions
return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions]) return np.array([
find_apparent_age(age_prediction) for age_prediction in age_predictions])
def load_model( def load_model(
@ -106,8 +108,8 @@ def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
Returns: Returns:
apparent_age (float) apparent_age (float)
""" """
assert len(age_predictions.shape) == 1, "Input should be a list of age predictions, \ assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \
not batched. Got shape: {}".format(age_predictions.shape) not batched. Got shape: {age_predictions.shape}"
output_indexes = np.arange(0, 101) output_indexes = np.arange(0, 101)
apparent_age = np.sum(age_predictions * output_indexes) apparent_age = np.sum(age_predictions * output_indexes)
return apparent_age return apparent_age