Merge pull request #688 from Vincent-Stragier/master

Close #532, add docstring and some refactoring.
This commit is contained in:
Sefik Ilkin Serengil 2023-03-01 15:22:17 +01:00 committed by GitHub
commit 05f309f357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 119 additions and 42 deletions

28
.vscode/settings.json vendored
View File

@ -1,17 +1,13 @@
{ {
"python.linting.pylintEnabled": true, "python.linting.pylintEnabled": true,
"python.linting.enabled": true, "python.linting.enabled": true,
"python.linting.pylintUseMinimalCheckers": false, "python.linting.pylintUseMinimalCheckers": false,
"editor.formatOnSave": true, "editor.formatOnSave": true,
"editor.renderWhitespace": "all", "editor.renderWhitespace": "all",
"files.autoSave": "afterDelay", "files.autoSave": "afterDelay",
"python.analysis.typeCheckingMode": "basic", "python.analysis.typeCheckingMode": "basic",
"python.formatting.provider": "black", "python.formatting.provider": "black",
"python.formatting.blackArgs": [ "python.formatting.blackArgs": ["--line-length=100"],
"--line-length=100" "editor.fontWeight": "normal",
], "python.analysis.extraPaths": ["./deepface"]
"editor.fontWeight": "normal", }
"python.analysis.extraPaths": [
"./deepface"
]
}

View File

@ -30,6 +30,11 @@ elif tf_major_version == 2:
def initialize_folder(): def initialize_folder():
"""Initialize the folder for storing weights and models.
Raises:
OSError: if the folder cannot be created.
"""
home = get_deepface_home() home = get_deepface_home()
if not os.path.exists(home + "/.deepface"): if not os.path.exists(home + "/.deepface"):
@ -42,6 +47,11 @@ def initialize_folder():
def get_deepface_home(): def get_deepface_home():
"""Get the home directory for storing weights and models.
Returns:
str: the home directory.
"""
return str(os.getenv("DEEPFACE_HOME", default=str(Path.home()))) return str(os.getenv("DEEPFACE_HOME", default=str(Path.home())))
@ -49,6 +59,14 @@ def get_deepface_home():
def loadBase64Img(uri): def loadBase64Img(uri):
"""Load image from base64 string.
Args:
uri: a base64 string.
Returns:
numpy array: the loaded image.
"""
encoded_data = uri.split(",")[1] encoded_data = uri.split(",")[1]
nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8) nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@ -56,34 +74,36 @@ def loadBase64Img(uri):
def load_image(img): def load_image(img):
exact_image = False """Load image from path, url, base64 or numpy array.
base64_img = False
url_img = False
Args:
img: a path, url, base64 or numpy array.
Raises:
ValueError: if the image path does not exist.
Returns:
numpy array: the loaded image.
"""
# The image is already a numpy array
if type(img).__module__ == np.__name__: if type(img).__module__ == np.__name__:
exact_image = True return img
elif img.startswith("data:image/"): # The image is a base64 string
base64_img = True if img.startswith("data:image/"):
return loadBase64Img(img)
elif img.startswith("http"): # The image is a url
url_img = True if img.startswith("http"):
return np.array(Image.open(requests.get(img, stream=True, timeout=60).raw).convert("RGB"))[
:, :, ::-1
]
# --------------------------- # The image is a path
if os.path.isfile(img) is not True:
raise ValueError(f"Confirm that {img} exists")
if base64_img is True: return cv2.imread(img)
img = loadBase64Img(img)
elif url_img is True:
img = np.array(Image.open(requests.get(img, stream=True, timeout=60).raw).convert("RGB"))
elif exact_image is not True: # image path passed as input
if os.path.isfile(img) is not True:
raise ValueError(f"Confirm that {img} exists")
img = cv2.imread(img)
return img
# -------------------------------------------------- # --------------------------------------------------
@ -97,6 +117,24 @@ def extract_faces(
enforce_detection=True, enforce_detection=True,
align=True, align=True,
): ):
"""Extract faces from an image.
Args:
img: a path, url, base64 or numpy array.
target_size (tuple, optional): the target size of the extracted faces.
Defaults to (224, 224).
detector_backend (str, optional): the face detector backend. Defaults to "opencv".
grayscale (bool, optional): whether to convert the extracted faces to grayscale.
Defaults to False.
enforce_detection (bool, optional): whether to enforce face detection. Defaults to True.
align (bool, optional): whether to align the extracted faces. Defaults to True.
Raises:
ValueError: if face could not be detected and enforce_detection is True.
Returns:
list: a list of extracted faces.
"""
# this is going to store a list of img itself (numpy), it region and confidence # this is going to store a list of img itself (numpy), it region and confidence
extracted_faces = [] extracted_faces = []
@ -123,7 +161,6 @@ def extract_faces(
for current_img, current_region, confidence in face_objs: for current_img, current_region, confidence in face_objs:
if current_img.shape[0] > 0 and current_img.shape[1] > 0: if current_img.shape[0] > 0 and current_img.shape[1] > 0:
if grayscale is True: if grayscale is True:
current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY) current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
@ -133,7 +170,10 @@ def extract_faces(
factor_1 = target_size[1] / current_img.shape[1] factor_1 = target_size[1] / current_img.shape[1]
factor = min(factor_0, factor_1) factor = min(factor_0, factor_1)
dsize = (int(current_img.shape[1] * factor), int(current_img.shape[0] * factor)) dsize = (
int(current_img.shape[1] * factor),
int(current_img.shape[0] * factor),
)
current_img = cv2.resize(current_img, dsize) current_img = cv2.resize(current_img, dsize)
diff_0 = target_size[0] - current_img.shape[0] diff_0 = target_size[0] - current_img.shape[0]
@ -152,7 +192,10 @@ def extract_faces(
else: else:
current_img = np.pad( current_img = np.pad(
current_img, current_img,
((diff_0 // 2, diff_0 - diff_0 // 2), (diff_1 // 2, diff_1 - diff_1 // 2)), (
(diff_0 // 2, diff_0 - diff_0 // 2),
(diff_1 // 2, diff_1 - diff_1 // 2),
),
"constant", "constant",
) )
@ -161,7 +204,8 @@ def extract_faces(
current_img = cv2.resize(current_img, target_size) current_img = cv2.resize(current_img, target_size)
# normalizing the image pixels # normalizing the image pixels
img_pixels = image.img_to_array(current_img) # what this line doing? must? # what this line doing? must?
img_pixels = image.img_to_array(current_img)
img_pixels = np.expand_dims(img_pixels, axis=0) img_pixels = np.expand_dims(img_pixels, axis=0)
img_pixels /= 255 # normalize input in [0, 1] img_pixels /= 255 # normalize input in [0, 1]
@ -185,6 +229,16 @@ def extract_faces(
def normalize_input(img, normalization="base"): def normalize_input(img, normalization="base"):
"""Normalize input image.
Args:
img (numpy array): the input image.
normalization (str, optional): the normalization technique. Defaults to "base",
for no normalization.
Returns:
numpy array: the normalized image.
"""
# issue 131 declares that some normalization techniques improves the accuracy # issue 131 declares that some normalization techniques improves the accuracy
@ -233,6 +287,14 @@ def normalize_input(img, normalization="base"):
def find_target_size(model_name): def find_target_size(model_name):
"""Find the target size of the model.
Args:
model_name (str): the model name.
Returns:
tuple: the target size.
"""
target_sizes = { target_sizes = {
"VGG-Face": (224, 224), "VGG-Face": (224, 224),
@ -267,6 +329,25 @@ def preprocess_face(
enforce_detection=True, enforce_detection=True,
align=True, align=True,
): ):
"""Preprocess face.
Args:
img (numpy array): the input image.
target_size (tuple, optional): the target size. Defaults to (224, 224).
detector_backend (str, optional): the detector backend. Defaults to "opencv".
grayscale (bool, optional): whether to convert to grayscale. Defaults to False.
enforce_detection (bool, optional): whether to enforce face detection. Defaults to True.
align (bool, optional): whether to align the face. Defaults to True.
Returns:
numpy array: the preprocessed face.
Raises:
ValueError: if face is not detected and enforce_detection is True.
Deprecated:
0.0.78: Use extract_faces instead of preprocess_face.
"""
print("⚠️ Function preprocess_face is deprecated. Use extract_faces instead.") print("⚠️ Function preprocess_face is deprecated. Use extract_faces instead.")
result = None result = None
img_objs = extract_faces( img_objs = extract_faces(