[fix] handle between grayscale and RGB image for models

This commit is contained in:
NatLee 2025-01-07 04:55:14 +08:00
parent 29141b3cd5
commit 36fb512bec

View File

@ -36,7 +36,8 @@ class Demography(ABC):
assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
if img_batch.shape[0] == 1: # Single image
img_batch = img_batch.squeeze(0) # Remove batch dimension
if img_batch.shape[-1] != 3: # Check if grayscale
img_batch = img_batch.squeeze(0) # Remove batch dimension
predict_result = self.model(img_batch, training=False).numpy()[0, :] # Predict with legacy method.
return predict_result
else: # Batch of images