patch: Lint

This commit is contained in:
h-alice 2025-01-14 09:12:35 +08:00
parent a23893a5fa
commit da4a0c5452
No known key found for this signature in database
GPG Key ID: 5708F34144A70909

View File

@ -38,14 +38,14 @@ class Demography(ABC):
if not self.model_name: # Check if called from derived class
raise NotImplementedError("no model selected")
assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
if img_batch.shape[0] == 1: # Single image
# Predict with legacy method.
return self.model(img_batch, training=False).numpy()[0, :]
else:
# Batch of images
# Predict with batch prediction
return self.model.predict_on_batch(img_batch)
# Batch of images
# Predict with batch prediction
return self.model.predict_on_batch(img_batch)
def _preprocess_batch_or_single_input(
self,