patch: Lint

This commit is contained in:
h-alice 2025-01-14 09:12:35 +08:00
parent a23893a5fa
commit da4a0c5452
No known key found for this signature in database
GPG Key ID: 5708F34144A70909

View File

@ -38,14 +38,14 @@ class Demography(ABC):
if not self.model_name: # Check if called from derived class if not self.model_name: # Check if called from derived class
raise NotImplementedError("no model selected") raise NotImplementedError("no model selected")
assert img_batch.ndim == 4, "expected 4-dimensional tensor input" assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
if img_batch.shape[0] == 1: # Single image if img_batch.shape[0] == 1: # Single image
# Predict with legacy method. # Predict with legacy method.
return self.model(img_batch, training=False).numpy()[0, :] return self.model(img_batch, training=False).numpy()[0, :]
else:
# Batch of images # Batch of images
# Predict with batch prediction # Predict with batch prediction
return self.model.predict_on_batch(img_batch) return self.model.predict_on_batch(img_batch)
def _preprocess_batch_or_single_input( def _preprocess_batch_or_single_input(
self, self,