mirror of https://github.com/serengil/deepface.git (synced 2025-06-06 19:45:21 +00:00)

resize functionality moved to represent module

We were handling resizing in extract_faces. With this commit we moved it to the representation module to provide separation of concerns.

parent 42ee2982f0
commit 1078be9f12
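In practice the public API change is simple: extract_faces no longer accepts target_size, and represent resizes crops to its model's input shape internally. A minimal sketch of the new call pattern (img.jpg is a placeholder path):

from deepface import DeepFace

# extract_faces now returns each crop at its detected size, RGB, scaled to [0, 1]
face_objs = DeepFace.extract_faces(img_path="img.jpg", detector_backend="opencv")

# represent pads/resizes to the chosen model's input shape on its own
embedding_objs = DeepFace.represent(img_path="img.jpg", model_name="VGG-Face")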
@@ -2,7 +2,7 @@
 import os
 import warnings
 import logging
-from typing import Any, Dict, List, Tuple, Union, Optional
+from typing import Any, Dict, List, Union, Optional

 # this has to be set before importing tensorflow
 os.environ["TF_USE_LEGACY_KERAS"] = "1"
@@ -439,7 +439,6 @@ def stream(

 def extract_faces(
     img_path: Union[str, np.ndarray],
-    target_size: Optional[Tuple[int, int]] = (224, 224),
     detector_backend: str = "opencv",
     enforce_detection: bool = True,
     align: bool = True,
@@ -453,9 +452,6 @@ def extract_faces(
         img_path (str or np.ndarray): Path to the first image. Accepts exact image path
             as a string, numpy array (BGR), or base64 encoded images.

-        target_size (tuple): final shape of facial image. black pixels will be
-            added to resize the image (default is (224, 224)).
-
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
             'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).

@@ -485,13 +481,11 @@ def extract_faces(

     return detection.extract_faces(
         img_path=img_path,
-        target_size=target_size,
         detector_backend=detector_backend,
         enforce_detection=enforce_detection,
         align=align,
         expand_percentage=expand_percentage,
         grayscale=grayscale,
-        human_readable=True,
     )

@@ -6,7 +6,7 @@ import numpy as np
 from tqdm import tqdm

 # project dependencies
-from deepface.modules import modeling, detection
+from deepface.modules import modeling, detection, preprocessing
 from deepface.extendedmodels import Gender, Race, Emotion

@@ -118,7 +118,6 @@ def analyze(

     img_objs = detection.extract_faces(
         img_path=img_path,
-        target_size=(224, 224),
         detector_backend=detector_backend,
         grayscale=False,
         enforce_detection=enforce_detection,
@@ -130,60 +129,68 @@ def analyze(
         img_content = img_obj["face"]
         img_region = img_obj["facial_area"]
         img_confidence = img_obj["confidence"]
-        if img_content.shape[0] > 0 and img_content.shape[1] > 0:
-            obj = {}
-            # facial attribute analysis
-            pbar = tqdm(
-                range(0, len(actions)),
-                desc="Finding actions",
-                disable=silent if len(actions) > 1 else True,
-            )
-            for index in pbar:
-                action = actions[index]
-                pbar.set_description(f"Action: {action}")
-
-                if action == "emotion":
-                    emotion_predictions = modeling.build_model("Emotion").predict(img_content)
-                    sum_of_predictions = emotion_predictions.sum()
-
-                    obj["emotion"] = {}
-                    for i, emotion_label in enumerate(Emotion.labels):
-                        emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions
-                        obj["emotion"][emotion_label] = emotion_prediction
-
-                    obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)]
-
-                elif action == "age":
-                    apparent_age = modeling.build_model("Age").predict(img_content)
-                    # int cast is for exception - object of type 'float32' is not JSON serializable
-                    obj["age"] = int(apparent_age)
-
-                elif action == "gender":
-                    gender_predictions = modeling.build_model("Gender").predict(img_content)
-                    obj["gender"] = {}
-                    for i, gender_label in enumerate(Gender.labels):
-                        gender_prediction = 100 * gender_predictions[i]
-                        obj["gender"][gender_label] = gender_prediction
-
-                    obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)]
-
-                elif action == "race":
-                    race_predictions = modeling.build_model("Race").predict(img_content)
-                    sum_of_predictions = race_predictions.sum()
-
-                    obj["race"] = {}
-                    for i, race_label in enumerate(Race.labels):
-                        race_prediction = 100 * race_predictions[i] / sum_of_predictions
-                        obj["race"][race_label] = race_prediction
-
-                    obj["dominant_race"] = Race.labels[np.argmax(race_predictions)]
-
-                # -----------------------------
-                # mention facial areas
-                obj["region"] = img_region
-                # include image confidence
-                obj["face_confidence"] = img_confidence
-
-            resp_objects.append(obj)
+        if img_content.shape[0] == 0 or img_content.shape[1] == 0:
+            continue
+
+        # rgb to bgr
+        img_content = img_content[:, :, ::-1]
+
+        # resize input image
+        img_content = preprocessing.resize_image(img=img_content, target_size=(224, 224))
+
+        obj = {}
+        # facial attribute analysis
+        pbar = tqdm(
+            range(0, len(actions)),
+            desc="Finding actions",
+            disable=silent if len(actions) > 1 else True,
+        )
+        for index in pbar:
+            action = actions[index]
+            pbar.set_description(f"Action: {action}")
+
+            if action == "emotion":
+                emotion_predictions = modeling.build_model("Emotion").predict(img_content)
+                sum_of_predictions = emotion_predictions.sum()
+
+                obj["emotion"] = {}
+                for i, emotion_label in enumerate(Emotion.labels):
+                    emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions
+                    obj["emotion"][emotion_label] = emotion_prediction
+
+                obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)]
+
+            elif action == "age":
+                apparent_age = modeling.build_model("Age").predict(img_content)
+                # int cast is for exception - object of type 'float32' is not JSON serializable
+                obj["age"] = int(apparent_age)
+
+            elif action == "gender":
+                gender_predictions = modeling.build_model("Gender").predict(img_content)
+                obj["gender"] = {}
+                for i, gender_label in enumerate(Gender.labels):
+                    gender_prediction = 100 * gender_predictions[i]
+                    obj["gender"][gender_label] = gender_prediction
+
+                obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)]
+
+            elif action == "race":
+                race_predictions = modeling.build_model("Race").predict(img_content)
+                sum_of_predictions = race_predictions.sum()
+
+                obj["race"] = {}
+                for i, race_label in enumerate(Race.labels):
+                    race_prediction = 100 * race_predictions[i] / sum_of_predictions
+                    obj["race"][race_label] = race_prediction
+
+                obj["dominant_race"] = Race.labels[np.argmax(race_predictions)]
+
+            # -----------------------------
+            # mention facial areas
+            obj["region"] = img_region
+            # include image confidence
+            obj["face_confidence"] = img_confidence
+
+        resp_objects.append(obj)

     return resp_objects
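With the resize gone from extract_faces, analyze() now prepares each face itself: skip empty crops early, flip the RGB crop back to BGR, and letterbox it to the demography models' fixed 224x224 input. A standalone sketch of that sequence (prepare_face is an illustrative helper, not a function in the codebase):

import numpy as np
from deepface.modules import preprocessing

def prepare_face(face_rgb: np.ndarray) -> np.ndarray:
    # mirror of the guard clause: analyze() skips empty crops via `continue`
    if face_rgb.shape[0] == 0 or face_rgb.shape[1] == 0:
        raise ValueError("empty face crop")
    # extract_faces now yields RGB in [0, 1]; the demography CNNs were trained on BGR
    face_bgr = face_rgb[:, :, ::-1]
    # letterbox to 224x224 and get back a (1, 224, 224, 3) batch for predict()
    return preprocessing.resize_image(img=face_bgr, target_size=(224, 224))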
@@ -1,5 +1,5 @@
 # built-in dependencies
-from typing import Any, Dict, List, Tuple, Union, Optional
+from typing import Any, Dict, List, Tuple, Union

 # 3rd part dependencies
 import numpy as np
@@ -10,7 +10,6 @@ from PIL import Image
 from deepface.modules import preprocessing
 from deepface.models.Detector import DetectedFace, FacialAreaRegion
 from deepface.detectors import DetectorWrapper
-from deepface.commons import package_utils
 from deepface.commons.logger import Logger

 logger = Logger(module="deepface/modules/detection.py")
@@ -18,22 +17,13 @@ logger = Logger(module="deepface/modules/detection.py")
 # pylint: disable=no-else-raise


-tf_major_version = package_utils.get_tf_major_version()
-if tf_major_version == 1:
-    from keras.preprocessing import image
-elif tf_major_version == 2:
-    from tensorflow.keras.preprocessing import image
-
-
 def extract_faces(
     img_path: Union[str, np.ndarray],
-    target_size: Optional[Tuple[int, int]] = (224, 224),
     detector_backend: str = "opencv",
     enforce_detection: bool = True,
     align: bool = True,
     expand_percentage: int = 0,
     grayscale: bool = False,
-    human_readable=False,
 ) -> List[Dict[str, Any]]:
     """
     Extract faces from a given image
@@ -42,9 +32,6 @@ def extract_faces(
         img_path (str or np.ndarray): Path to the first image. Accepts exact image path
             as a string, numpy array (BGR), or base64 encoded images.

-        target_size (tuple): final shape of facial image. black pixels will be
-            added to resize the image.
-
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
             'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv)

@@ -58,13 +45,10 @@ def extract_faces(
         grayscale (boolean): Flag to convert the image to grayscale before
             processing (default is False).

-        human_readable (bool): Flag to make the image human readable. 3D RGB for human readable
-            or 4D BGR for ML models (default is False).
-
     Returns:
         results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary contains:

-        - "face" (np.ndarray): The detected face as a NumPy array.
+        - "face" (np.ndarray): The detected face as a NumPy array in RGB format.

         - "facial_area" (Dict[str, Any]): The detected face's regions as a dictionary containing:
           - keys 'x', 'y', 'w', 'h' with int values
@@ -122,57 +106,11 @@ def extract_faces(
         if grayscale is True:
             current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)

-        # resize and padding
-        if target_size is not None:
-            factor_0 = target_size[0] / current_img.shape[0]
-            factor_1 = target_size[1] / current_img.shape[1]
-            factor = min(factor_0, factor_1)
-
-            dsize = (
-                int(current_img.shape[1] * factor),
-                int(current_img.shape[0] * factor),
-            )
-            current_img = cv2.resize(current_img, dsize)
-
-            diff_0 = target_size[0] - current_img.shape[0]
-            diff_1 = target_size[1] - current_img.shape[1]
-            if grayscale is False:
-                # Put the base image in the middle of the padded image
-                current_img = np.pad(
-                    current_img,
-                    (
-                        (diff_0 // 2, diff_0 - diff_0 // 2),
-                        (diff_1 // 2, diff_1 - diff_1 // 2),
-                        (0, 0),
-                    ),
-                    "constant",
-                )
-            else:
-                current_img = np.pad(
-                    current_img,
-                    (
-                        (diff_0 // 2, diff_0 - diff_0 // 2),
-                        (diff_1 // 2, diff_1 - diff_1 // 2),
-                    ),
-                    "constant",
-                )
-
-            # double check: if target image is not still the same size with target.
-            if current_img.shape[0:2] != target_size:
-                current_img = cv2.resize(current_img, target_size)
-
-        # normalizing the image pixels
-        # what this line doing? must?
-        img_pixels = image.img_to_array(current_img)
-        img_pixels = np.expand_dims(img_pixels, axis=0)
-        img_pixels /= 255  # normalize input in [0, 1]
-        # discard expanded dimension
-        if human_readable is True and len(img_pixels.shape) == 4:
-            img_pixels = img_pixels[0]
+        current_img = current_img / 255  # normalize input in [0, 1]

         resp_objs.append(
             {
-                "face": img_pixels[:, :, ::-1] if human_readable is True else img_pixels,
+                "face": current_img[:, :, ::-1],
                 "facial_area": {
                     "x": int(current_region.x),
                     "y": int(current_region.y),
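Two side effects of this hunk: the detection module no longer needs Keras at all (the img_to_array/expand_dims steps are gone), and the returned crop is just the detected region scaled into [0, 1] and flipped to RGB. A hedged sanity check, with img.jpg as a placeholder for any local photo containing a face:

from deepface.modules import detection

face_objs = detection.extract_faces(img_path="img.jpg")
face = face_objs[0]["face"]

# no padding to a model input shape anymore: the crop keeps its detected size,
# stays 3-dimensional, and is float RGB in [0, 1]
assert face.ndim == 3
assert 0.0 <= face.min() and face.max() <= 1.0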
@@ -11,6 +11,16 @@ import cv2
 import requests
 from PIL import Image

+# project dependencies
+from deepface.commons import package_utils
+
+
+tf_major_version = package_utils.get_tf_major_version()
+if tf_major_version == 1:
+    from keras.preprocessing import image
+elif tf_major_version == 2:
+    from tensorflow.keras.preprocessing import image
+

 def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
     """
@@ -66,8 +76,8 @@ def load_image_from_web(url: str) -> np.ndarray:
     response = requests.get(url, stream=True, timeout=60)
     response.raise_for_status()
     image_array = np.asarray(bytearray(response.raw.read()), dtype=np.uint8)
-    image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
-    return image
+    img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
+    return img


 def load_base64(uri: str) -> np.ndarray:
@@ -157,3 +167,50 @@ def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray:
         raise ValueError(f"unimplemented normalization type - {normalization}")

     return img
+
+
+def resize_image(img: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
+    """
+    Resize an image to expected size of a ml model with adding black pixels.
+    Args:
+        img (np.ndarray): pre-loaded image as numpy array
+        target_size (tuple): input shape of ml model
+    Returns:
+        img (np.ndarray): resized input image
+    """
+    factor_0 = target_size[0] / img.shape[0]
+    factor_1 = target_size[1] / img.shape[1]
+    factor = min(factor_0, factor_1)
+
+    dsize = (
+        int(img.shape[1] * factor),
+        int(img.shape[0] * factor),
+    )
+    img = cv2.resize(img, dsize)
+
+    diff_0 = target_size[0] - img.shape[0]
+    diff_1 = target_size[1] - img.shape[1]
+
+    # Put the base image in the middle of the padded image
+    img = np.pad(
+        img,
+        (
+            (diff_0 // 2, diff_0 - diff_0 // 2),
+            (diff_1 // 2, diff_1 - diff_1 // 2),
+            (0, 0),
+        ),
+        "constant",
+    )
+
+    # double check: if target image is not still the same size with target.
+    if img.shape[0:2] != target_size:
+        img = cv2.resize(img, target_size)
+
+    # make it 4-dimensional how ML models expect
+    img = image.img_to_array(img)
+    img = np.expand_dims(img, axis=0)
+
+    if img.max() > 1:
+        img = (img.astype(np.float32) / 255.0).astype(np.float32)
+
+    return img
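The relocated helper reproduces the old letterboxing behavior: scale with the aspect ratio preserved, pad with black pixels, and emit a 4D batch in [0, 1]. A quick usage sketch:

import numpy as np
from deepface.modules import preprocessing

crop = np.random.rand(180, 120, 3)  # any face crop; float in [0, 1] here
batch = preprocessing.resize_image(img=crop, target_size=(224, 224))

assert batch.shape == (1, 224, 224, 3)  # batch dimension added for the model
assert batch.max() <= 1.0               # pixel range stays normalized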
@@ -13,8 +13,7 @@ from PIL import Image
 # project dependencies
 from deepface.commons.logger import Logger
 from deepface.commons import package_utils
-from deepface.modules import representation, detection, modeling, verification
-from deepface.models.FacialRecognition import FacialRecognition
+from deepface.modules import representation, detection, verification

 logger = Logger(module="deepface/modules/recognition.py")

@@ -90,15 +89,9 @@ def find(

     tic = time.time()

-    # -------------------------------
     if os.path.isdir(db_path) is not True:
         raise ValueError("Passed db_path does not exist!")

-    model: FacialRecognition = modeling.build_model(model_name)
-    target_size = model.input_shape
-
-    # ---------------------------------------
-
     file_name = f"ds_{model_name}_{detector_backend}_v2.pkl"
     file_name = file_name.replace("-", "").lower()
     datastore_path = os.path.join(db_path, file_name)
@@ -180,7 +173,6 @@ def find(
         representations += __find_bulk_embeddings(
             employees=new_images,
             model_name=model_name,
-            target_size=target_size,
             detector_backend=detector_backend,
             enforce_detection=enforce_detection,
             align=align,
@@ -212,7 +204,6 @@ def find(
     # img path might have more than once face
     source_objs = detection.extract_faces(
         img_path=img_path,
-        target_size=target_size,
         detector_backend=detector_backend,
         grayscale=False,
         enforce_detection=enforce_detection,
@@ -314,7 +305,6 @@ def __list_images(path: str) -> List[str]:
 def __find_bulk_embeddings(
     employees: List[str],
     model_name: str = "VGG-Face",
-    target_size: tuple = (224, 224),
     detector_backend: str = "opencv",
     enforce_detection: bool = True,
     align: bool = True,
@@ -362,7 +352,6 @@ def __find_bulk_embeddings(
         try:
             img_objs = detection.extract_faces(
                 img_path=employee,
-                target_size=target_size,
                 detector_backend=detector_backend,
                 grayscale=False,
                 enforce_detection=enforce_detection,
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Union

 # 3rd party dependencies
 import numpy as np
-import cv2

 # project dependencies
 from deepface.modules import modeling, detection, preprocessing
@@ -67,7 +66,6 @@ def represent(
     if detector_backend != "skip":
         img_objs = detection.extract_faces(
             img_path=img_path,
-            target_size=(target_size[1], target_size[0]),
             detector_backend=detector_backend,
             grayscale=False,
             enforce_detection=enforce_detection,
@@ -77,16 +75,10 @@ def represent(
     else:  # skip
         # Try load. If load error, will raise exception internal
         img, _ = preprocessing.load_image(img_path)
-        # --------------------------------
-        if len(img.shape) == 4:
-            img = img[0]  # e.g. (1, 224, 224, 3) to (224, 224, 3)
-        if len(img.shape) == 3:
-            img = cv2.resize(img, target_size)
-            img = np.expand_dims(img, axis=0)
-            # when called from verify, this is already normalized. But needed when user given.
-            if img.max() > 1:
-                img = (img.astype(np.float32) / 255.0).astype(np.float32)
-        # --------------------------------
+
+        if len(img.shape) != 3:
+            raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
+
         # make dummy region and confidence to keep compatibility with `extract_faces`
         img_objs = [
             {
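The skip path also drops its silent fixups: a pre-batched 4D array used to be squeezed and resized, now anything other than a single 3D image raises. A sketch of the new contract (this would download model weights on first run):

import numpy as np
from deepface import DeepFace

img = np.zeros((224, 224, 3))        # a single 3D image is still accepted
DeepFace.represent(img_path=img, detector_backend="skip", enforce_detection=False)

batch = np.zeros((1, 224, 224, 3))   # 4D input now raises instead of being squeezed
try:
    DeepFace.represent(img_path=batch, detector_backend="skip", enforce_detection=False)
except ValueError as err:
    print(err)  # Input img must be 3 dimensional but it is (1, 224, 224, 3)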
@@ -99,8 +91,20 @@ def represent(

     for img_obj in img_objs:
         img = img_obj["face"]
+
+        # rgb to bgr
+        img = img[:, :, ::-1]
+
         region = img_obj["facial_area"]
         confidence = img_obj["confidence"]
+
+        # resize to expected shape of ml model
+        img = preprocessing.resize_image(
+            img=img,
+            # thanks to DeepId (!)
+            target_size=(target_size[1], target_size[0]),
+        )
+
         # custom normalization
         img = preprocessing.normalize_input(img=img, normalization=normalization)

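After this hunk, represent() is the single place that knows a model's input geometry. The (target_size[1], target_size[0]) swap exists because the axis order of a model's declared input_shape and the order resize_image expects differ for non-square inputs, which the "# thanks to DeepId (!)" comment alludes to. A rough sketch of the flow (Facenet is only an example model):

import numpy as np
from deepface.modules import modeling, preprocessing

face_rgb = np.random.rand(150, 140, 3)   # stand-in for an extract_faces crop
face_bgr = face_rgb[:, :, ::-1]          # represent() flips back to BGR first

model = modeling.build_model("Facenet")  # example; input_shape is (160, 160)
batch = preprocessing.resize_image(
    img=face_bgr,
    target_size=(model.input_shape[1], model.input_shape[0]),  # axis swap, cf. DeepId
)
# `batch` is what then gets normalized and fed to the recognition model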
@@ -62,7 +62,7 @@ def analysis(
     """
     # initialize models
     build_demography_models(enable_face_analysis=enable_face_analysis)
-    target_size = build_facial_recognition_model(model_name=model_name)
+    build_facial_recognition_model(model_name=model_name)
     # call a dummy find function for db_path once to create embeddings before starting webcam
     _ = search_identity(
         detected_face=np.zeros([224, 224, 3]),
@@ -89,9 +89,7 @@ def analysis(

         faces_coordinates = []
         if freeze is False:
-            faces_coordinates = grab_facial_areas(
-                img=img, detector_backend=detector_backend, target_size=target_size
-            )
+            faces_coordinates = grab_facial_areas(img=img, detector_backend=detector_backend)

         # we will pass img to analyze modules (identity, demography) and add some illustrations
         # that is why, we will not be able to extract detected face from img clearly
@@ -156,7 +154,7 @@ def analysis(
     cv2.destroyAllWindows()


-def build_facial_recognition_model(model_name: str) -> tuple:
+def build_facial_recognition_model(model_name: str) -> None:
     """
     Build facial recognition model
     Args:
@@ -165,9 +163,8 @@ def build_facial_recognition_model(model_name: str) -> tuple:
     Returns
         input_shape (tuple): input shape of given facial recognition model.
     """
-    model: FacialRecognition = DeepFace.build_model(model_name=model_name)
+    _ = DeepFace.build_model(model_name=model_name)
     logger.info(f"{model_name} is built")
-    return model.input_shape


 def search_identity(
@@ -231,7 +228,6 @@ def search_identity(
     # load found identity image - extracted if possible
     target_objs = DeepFace.extract_faces(
         img_path=target_path,
-        target_size=(IDENTIFIED_IMG_SIZE, IDENTIFIED_IMG_SIZE),
         detector_backend=detector_backend,
         enforce_detection=False,
         align=True,
@@ -243,6 +239,7 @@ def search_identity(
         # extract 1st item directly
         target_obj = target_objs[0]
         target_img = target_obj["face"]
+        target_img = cv2.resize(target_img, (IDENTIFIED_IMG_SIZE, IDENTIFIED_IMG_SIZE))
         target_img *= 255
         target_img = target_img[:, :, ::-1]
     else:
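In the streaming module the resize serves display rather than inference, so a plain cv2.resize with no padding is enough before converting back to 0-255 BGR. Roughly, assuming 112 as a stand-in for the module's IDENTIFIED_IMG_SIZE constant:

import cv2
import numpy as np

IDENTIFIED_IMG_SIZE = 112  # illustrative stand-in for the module's constant

face = np.random.rand(90, 70, 3)  # extract_faces output: [0, 1] RGB, native size
display = cv2.resize(face, (IDENTIFIED_IMG_SIZE, IDENTIFIED_IMG_SIZE))
display = (display * 255)[:, :, ::-1]  # back to 0-255 BGR for OpenCV windows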
@@ -346,7 +343,7 @@ def countdown_to_release(


 def grab_facial_areas(
-    img: np.ndarray, detector_backend: str, target_size: Tuple[int, int], threshold: int = 130
+    img: np.ndarray, detector_backend: str, threshold: int = 130
 ) -> List[Tuple[int, int, int, int]]:
     """
     Find facial area coordinates in the given image
@@ -354,7 +351,6 @@ def grab_facial_areas(
         img (np.ndarray): image itself
         detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
             'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8' (default is opencv).
-        target_size (tuple): input shape of the facial recognition model.
         threshold (int): threshold for facial area, discard smaller ones
     Returns
         result (list): list of tuple with x, y, w and h coordinates
@@ -363,7 +359,6 @@ def grab_facial_areas(
         face_objs = DeepFace.extract_faces(
             img_path=img,
             detector_backend=detector_backend,
-            target_size=target_size,
             # you may consider to extract with larger expanding value
             expand_percentage=0,
         )
@@ -223,12 +223,8 @@ def __extract_faces_and_embeddings(
     embeddings = []
     facial_areas = []

-    model: FacialRecognition = modeling.build_model(model_name)
-    target_size = model.input_shape
-
     img_objs = detection.extract_faces(
         img_path=img_path,
-        target_size=target_size,
         detector_backend=detector_backend,
         grayscale=False,
         enforce_detection=enforce_detection,
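Verification gets the same simplification as find(): no more building the recognition model up front just to read its input_shape. Detection is now model-agnostic and the model is built lazily where it is used; a sketch, with img.jpg as a placeholder:

from deepface.modules import detection, representation

# model-agnostic: no input_shape needed before detection
img_objs = detection.extract_faces(img_path="img.jpg", detector_backend="opencv")

# representation builds the model and resizes each crop to fit it
embedding_objs = representation.represent(img_path="img.jpg", model_name="VGG-Face")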
@@ -18,6 +18,7 @@ model_names = [
     "Dlib",
     "ArcFace",
     "SFace",
+    "GhostFaceNet",
 ]

 detector_backends = [