handling many faces

This commit is contained in:
Sefik Ilkin Serengil 2023-01-23 21:54:52 +00:00
parent ba3db18671
commit b62e3671f8
6 changed files with 405 additions and 649 deletions

1
.gitignore vendored
View File

@ -17,6 +17,7 @@ deepface/extendedmodels/__pycache__/*
deepface/subsidiarymodels/__pycache__/* deepface/subsidiarymodels/__pycache__/*
deepface/detectors/__pycache__/* deepface/detectors/__pycache__/*
tests/dataset/*.pkl tests/dataset/*.pkl
tests/sandbox.ipynb
.DS_Store .DS_Store
deepface/.DS_Store deepface/.DS_Store
*.pyc *.pyc

File diff suppressed because it is too large Load Diff

View File

@ -27,24 +27,6 @@ elif tf_major_version == 2:
#-------------------------------------------------- #--------------------------------------------------
def initialize_input(img1_path, img2_path = None):
if type(img1_path) == list:
bulkProcess = True
img_list = img1_path.copy()
else:
bulkProcess = False
if (
(type(img2_path) == str and img2_path != None) #exact image path, base64 image
or (isinstance(img2_path, np.ndarray) and img2_path.any()) #numpy array
):
img_list = [[img1_path, img2_path]]
else: #analyze function passes just img1_path
img_list = [img1_path]
return img_list, bulkProcess
def initialize_folder(): def initialize_folder():
home = get_deepface_home() home = get_deepface_home()
@ -59,6 +41,8 @@ def initialize_folder():
def get_deepface_home(): def get_deepface_home():
return str(os.getenv('DEEPFACE_HOME', default=Path.home())) return str(os.getenv('DEEPFACE_HOME', default=Path.home()))
#--------------------------------------------------
def loadBase64Img(uri): def loadBase64Img(uri):
encoded_data = uri.split(',')[1] encoded_data = uri.split(',')[1]
nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8) nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8)
@ -93,35 +77,71 @@ def load_image(img):
return img return img
def detect_face(img, detector_backend = 'opencv', grayscale = False, enforce_detection = True, align = True): #--------------------------------------------------
def extract_faces(img, target_size=(224, 224), detector_backend = 'opencv', grayscale = False, enforce_detection = True, align = True):
# this is going to store a list of img itself (numpy), it region and confidence
extracted_faces = []
#img might be path, base64 or numpy array. Convert it to numpy whatever it is.
img = load_image(img)
img_region = [0, 0, img.shape[1], img.shape[0]] img_region = [0, 0, img.shape[1], img.shape[0]]
#----------------------------------------------
#people would like to skip detection and alignment if they already have pre-processed images
if detector_backend == 'skip': if detector_backend == 'skip':
return img, img_region face_objs = [(img, img_region, 0)]
#----------------------------------------------
#detector stored in a global variable in FaceDetector object.
#this call should be completed very fast because it will return found in memory
#it will not build face detector model in each call (consider for loops)
face_detector = FaceDetector.build_model(detector_backend)
try:
detected_face, img_region, _ = FaceDetector.detect_face(face_detector, detector_backend, img, align)
except: #if detected face shape is (0, 0) and alignment cannot be performed, this block will be run
detected_face = None
if (isinstance(detected_face, np.ndarray)):
return detected_face, img_region
else: else:
if detected_face == None: face_detector = FaceDetector.build_model(detector_backend)
if enforce_detection != True: face_objs = FaceDetector.detect_faces(face_detector, detector_backend, img, align)
return img, img_region
else: # in case of no face found
raise ValueError("Face could not be detected. Please confirm that the picture is a face photo or consider to set enforce_detection param to False.") if len(face_objs) == 0 and enforce_detection == True:
raise ValueError("Face could not be detected. Please confirm that the picture is a face photo or consider to set enforce_detection param to False.")
elif len(face_objs) == 0 and enforce_detection == False:
face_objs = [(img, img_region, 0)]
for current_img, current_region, confidence in face_objs:
if current_img.shape[0] > 0 and current_img.shape[1] > 0:
if grayscale == True:
current_img = cv2.cvtColor(current_img, cv2.COLOR_BGR2GRAY)
# resize and padding
if current_img.shape[0] > 0 and current_img.shape[1] > 0:
factor_0 = target_size[0] / current_img.shape[0]
factor_1 = target_size[1] / current_img.shape[1]
factor = min(factor_0, factor_1)
dsize = (int(current_img.shape[1] * factor), int(current_img.shape[0] * factor))
current_img = cv2.resize(current_img, dsize)
diff_0 = target_size[0] - current_img.shape[0]
diff_1 = target_size[1] - current_img.shape[1]
if grayscale == False:
# Put the base image in the middle of the padded image
current_img = np.pad(current_img, ((diff_0 // 2, diff_0 - diff_0 // 2), (diff_1 // 2, diff_1 - diff_1 // 2), (0, 0)), 'constant')
else:
current_img = np.pad(current_img, ((diff_0 // 2, diff_0 - diff_0 // 2), (diff_1 // 2, diff_1 - diff_1 // 2)), 'constant')
#double check: if target image is not still the same size with target.
if current_img.shape[0:2] != target_size:
current_img = cv2.resize(current_img, target_size)
#normalizing the image pixels
img_pixels = image.img_to_array(current_img) #what this line doing? must?
img_pixels = np.expand_dims(img_pixels, axis = 0)
img_pixels /= 255 #normalize input in [0, 1]
#int cast is for the exception - object of type 'float32' is not JSON serializable
region_obj = {"x": int(current_region[0]), "y": int(current_region[1]), "w": int(current_region[2]), "h": int(current_region[3])}
extracted_face = [img_pixels, region_obj, confidence]
extracted_faces.append(extracted_face)
if len(extracted_faces) == 0 and enforce_detection == True:
raise ValueError("Detected face shape is ", img.shape,". Consider to set enforce_detection argument to False.")
return extracted_faces
def normalize_input(img, normalization = 'base'): def normalize_input(img, normalization = 'base'):
@ -169,94 +189,21 @@ def normalize_input(img, normalization = 'base'):
return img return img
def preprocess_face(img, target_size=(224, 224), grayscale = False, enforce_detection = True, detector_backend = 'opencv', return_region = False, align = True): def find_target_size(model_name):
#img might be path, base64 or numpy array. Convert it to numpy whatever it is. target_sizes = {
img = load_image(img) "VGG-Face": (224, 224),
base_img = img.copy() "Facenet": (160, 160),
"Facenet512": (160, 160),
"OpenFace": (96, 96),
"DeepFace": (152, 152),
"DeepID": (55, 47), #TODO: might be opposite
"Dlib": (150, 150),
"ArcFace": (112, 112),
"SFace": (112, 112)
}
img, region = detect_face(img = img, detector_backend = detector_backend, grayscale = grayscale, enforce_detection = enforce_detection, align = align) if model_name not in target_sizes.keys():
raise ValueError(f"unimplemented model name - {model_name}")
#--------------------------
return target_sizes[model_name]
if img.shape[0] == 0 or img.shape[1] == 0:
if enforce_detection == True:
raise ValueError("Detected face shape is ", img.shape,". Consider to set enforce_detection argument to False.")
else: #restore base image
img = base_img.copy()
#--------------------------
#post-processing
if grayscale == True:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#---------------------------------------------------
#resize image to expected shape
# img = cv2.resize(img, target_size) #resize causes transformation on base image, adding black pixels to resize will not deform the base image
if img.shape[0] > 0 and img.shape[1] > 0:
factor_0 = target_size[0] / img.shape[0]
factor_1 = target_size[1] / img.shape[1]
factor = min(factor_0, factor_1)
dsize = (int(img.shape[1] * factor), int(img.shape[0] * factor))
img = cv2.resize(img, dsize)
# Then pad the other side to the target size by adding black pixels
diff_0 = target_size[0] - img.shape[0]
diff_1 = target_size[1] - img.shape[1]
if grayscale == False:
# Put the base image in the middle of the padded image
img = np.pad(img, ((diff_0 // 2, diff_0 - diff_0 // 2), (diff_1 // 2, diff_1 - diff_1 // 2), (0, 0)), 'constant')
else:
img = np.pad(img, ((diff_0 // 2, diff_0 - diff_0 // 2), (diff_1 // 2, diff_1 - diff_1 // 2)), 'constant')
#------------------------------------------
#double check: if target image is not still the same size with target.
if img.shape[0:2] != target_size:
img = cv2.resize(img, target_size)
#---------------------------------------------------
#normalizing the image pixels
img_pixels = image.img_to_array(img) #what this line doing? must?
img_pixels = np.expand_dims(img_pixels, axis = 0)
img_pixels /= 255 #normalize input in [0, 1]
#---------------------------------------------------
if return_region == True:
return img_pixels, region
else:
return img_pixels
def find_input_shape(model):
#face recognition models have different size of inputs
#my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
input_shape = model.layers[0].input_shape
if type(input_shape) == list:
input_shape = input_shape[0][1:3]
else:
input_shape = input_shape[1:3]
#----------------------
#issue 289: it seems that tf 2.5 expects you to resize images with (x, y)
#whereas its older versions expect (y, x)
if tf_major_version == 2 and tf_minor_version >= 5:
x = input_shape[0]; y = input_shape[1]
input_shape = (y, x)
#----------------------
if type(input_shape) == list: #issue 197: some people got array here instead of tuple
input_shape = tuple(input_shape)
return input_shape

View File

@ -50,7 +50,7 @@ def analysis(db_path, model_name = 'VGG-Face', detector_backend = 'opencv', dist
#------------------------ #------------------------
input_shape = functions.find_input_shape(model) input_shape = functions.find_target_size(model_name=model_name)
input_shape_x = input_shape[0]; input_shape_y = input_shape[1] input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
#tuned thresholds for model and metric pair #tuned thresholds for model and metric pair

BIN
tests/dataset/couple.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 923 KiB

View File

@ -1,6 +1,7 @@
import warnings import warnings
import os import os
import tensorflow as tf import tensorflow as tf
import numpy as np
import cv2 import cv2
from deepface import DeepFace from deepface import DeepFace