patch: Greyscale image prediction condition.

This commit is contained in:
h-alice 2025-01-13 23:14:40 +08:00
parent 8883b212b2
commit fa4044adae
No known key found for this signature in database
GPG Key ID: 5708F34144A70909

View File

@ -38,14 +38,15 @@ class Demography(ABC):
if not self.model_name: # Check if called from derived class
raise NotImplementedError("no model selected")
assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
# Single image
if img_batch.shape[0] == 1:
if img_batch.shape[-1] != 3: # Handle grayscale image, check last dimension.
# Check if grayscale by checking last dimension, if not 3, it is grayscale.
if img_batch.shape[-1] != 3:
# Remove batch dimension
img_batch = img_batch.squeeze(0)
img_batch = img_batch.squeeze(0) # Remove batch dimension
if img_batch.shape[0] == 1: # Single image
# Predict with legacy method.
return self.model(img_batch, training=False).numpy()[0, :]
# Batch of images
# Predict with batch prediction
return self.model.predict_on_batch(img_batch)