diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index afc54ea..9740fb0 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -657,6 +657,13 @@ def represent( img = img_path.copy() else: raise ValueError(f"unexpected type for img_path - {type(img_path)}") + # -------------------------------- + if len(img.shape) == 4: + img = img[0] # e.g. (1, 224, 224, 3) to (224, 224, 3) + if len(img.shape) == 3: + img = cv2.resize(img, target_size) + img = np.expand_dims(img, axis=0) + # -------------------------------- img_region = [0, 0, img.shape[1], img.shape[0]] img_objs = [(img, img_region, 0)] # ---------------------------------