patch: fix dimension

This commit is contained in:
h-alice 2025-01-13 23:35:48 +08:00
parent 910d6e1d80
commit 72b6db19d6
No known key found for this signature in database
GPG Key ID: 5708F34144A70909

View File

@ -40,7 +40,6 @@ class Demography(ABC):
assert img_batch.ndim == 4, "expected 4-dimensional tensor input" assert img_batch.ndim == 4, "expected 4-dimensional tensor input"
if img_batch.shape[0] == 1: # Single image if img_batch.shape[0] == 1: # Single image
img_batch = img_batch.squeeze(0) # Remove batch dimension
# Predict with legacy method. # Predict with legacy method.
return self.model(img_batch, training=False).numpy()[0, :] return self.model(img_batch, training=False).numpy()[0, :]
else: else: