diff --git a/deepface/modules/preprocessing.py b/deepface/modules/preprocessing.py index 7daf248..23edb5d 100644 --- a/deepface/modules/preprocessing.py +++ b/deepface/modules/preprocessing.py @@ -1,7 +1,9 @@ +# built-in dependencies import os from typing import Union, Tuple import base64 from pathlib import Path +import imghdr # 3rd party import numpy as np @@ -82,16 +84,16 @@ def load_base64(uri: str) -> np.ndarray: if len(encoded_data_parts) < 2: raise ValueError("format error in base64 encoded string") - # similar to find functionality, we are just considering these extensions - if not ( - uri.startswith("data:image/jpeg") - or uri.startswith("data:image/jpg") - or uri.startswith("data:image/png") - ): - raise ValueError(f"input image can be jpg, jpeg or png, but it is {encoded_data_parts}") - encoded_data = encoded_data_parts[1] - nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8) + decoded_bytes = base64.b64decode(encoded_data) + file_type = imghdr.what(None, h=decoded_bytes) + + # similar to find functionality, we are just considering these extensions + # content type is safer option than file extension + if file_type not in ["jpeg", "png"]: + raise ValueError(f"input image can be jpg or png, but it is {file_type}") + + nparr = np.fromstring(decoded_bytes, np.uint8) img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) return img_bgr diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index a902419..18e1a93 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -3,6 +3,7 @@ import os import pickle from typing import List, Union, Optional, Dict, Any import time +import imghdr # 3rd party dependencies import numpy as np @@ -296,8 +297,9 @@ def __list_images(path: str) -> List[str]: images = [] for r, _, f in os.walk(path): for file in f: - if file.lower().endswith((".jpg", ".jpeg", ".png")): - exact_path = os.path.join(r, file) + exact_path = os.path.join(r, file) + file_type = imghdr.what(exact_path) + if file_type in ["jpeg", "png"]: images.append(exact_path) return images diff --git a/tests/test_extract_faces.py b/tests/test_extract_faces.py index b8119fd..7e080de 100644 --- a/tests/test_extract_faces.py +++ b/tests/test_extract_faces.py @@ -1,6 +1,8 @@ import numpy as np +import base64 import pytest from deepface import DeepFace +from deepface.modules import preprocessing from deepface.commons.logger import Logger logger = Logger("tests/test_extract_faces.py") @@ -48,3 +50,24 @@ def test_backends_for_not_enforced_detection_with_non_facial_inputs(): ) assert objs[0]["face"].shape == (224, 224, 3) logger.info("✅ extract_faces for not enforced detection and non-facial image test is done") + + +def test_file_types_while_loading_base64(): + img1_path = "dataset/img47.jpg" + img1_base64 = image_to_base64(image_path=img1_path) + + with pytest.raises(ValueError, match="input image can be jpg or png, but it is"): + _ = preprocessing.load_base64(uri=img1_base64) + + img2_path = "dataset/img1.jpg" + img2_base64 = image_to_base64(image_path=img2_path) + + img2 = preprocessing.load_base64(uri=img2_base64) + # 3 dimensional image should be loaded + assert len(img2.shape) == 3 + + +def image_to_base64(image_path): + with open(image_path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + return "data:image/jpeg," + encoded_string diff --git a/tests/test_find.py b/tests/test_find.py index 5218724..5443899 100644 --- a/tests/test_find.py +++ b/tests/test_find.py @@ -3,6 +3,7 @@ import cv2 import pandas as pd from deepface import DeepFace from deepface.modules import verification +from deepface.modules import recognition from deepface.commons.logger import Logger logger = Logger("tests/test_find.py") @@ -11,7 +12,7 @@ threshold = verification.find_threshold(model_name="VGG-Face", distance_metric=" def test_find_with_exact_path(): - img_path = os.path.join("dataset","img1.jpg") + img_path = os.path.join("dataset", "img1.jpg") dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True) assert len(dfs) > 0 for df in dfs: @@ -31,7 +32,7 @@ def test_find_with_exact_path(): def test_find_with_array_input(): - img_path = os.path.join("dataset","img1.jpg") + img_path = os.path.join("dataset", "img1.jpg") img1 = cv2.imread(img_path) dfs = DeepFace.find(img1, db_path="dataset", silent=True) assert len(dfs) > 0 @@ -53,7 +54,7 @@ def test_find_with_array_input(): def test_find_with_extracted_faces(): - img_path = os.path.join("dataset","img1.jpg") + img_path = os.path.join("dataset", "img1.jpg") face_objs = DeepFace.extract_faces(img_path) img = face_objs[0]["face"] dfs = DeepFace.find(img, db_path="dataset", detector_backend="skip", silent=True) @@ -72,3 +73,25 @@ def test_find_with_extracted_faces(): logger.debug(df.head()) assert df.shape[0] > 0 logger.info("✅ test find for extracted face input done") + + +def test_filetype_for_find(): + """ + only images as jpg and png can be loaded into database + """ + img_path = os.path.join("dataset", "img1.jpg") + dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True) + + df = dfs[0] + + # img47 is webp even though its extension is jpg + assert df[df["identity"] == "dataset/img47.jpg"].shape[0] == 0 + + +def test_filetype_for_find_bulk_embeddings(): + imgs = recognition.__list_images("dataset") + + assert len(imgs) > 0 + + # img47 is webp even though its extension is jpg + assert "dataset/img47.jpg" not in imgs