detector argument for stream

Sefik Ilkin Serengil 2021-07-27 00:15:15 +03:00
parent af13e4558f
commit d3e03c3e5a
12 changed files with 155 additions and 76 deletions

View File

@@ -746,7 +746,7 @@ def represent(img_path, model_name = 'VGG-Face', model = None, enforce_detection
return embedding
-def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True, source = 0, time_threshold = 5, frame_threshold = 5):
+def stream(db_path = '', model_name ='VGG-Face', detector_backend = 'opencv', distance_metric = 'cosine', enable_face_analysis = True, source = 0, time_threshold = 5, frame_threshold = 5):
"""
This function applies real time face recognition and facial attribute analysis
@@ -756,6 +756,8 @@ def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', ena
model_name (string): VGG-Face, Facenet, OpenFace, DeepFace, DeepID, Dlib or Ensemble
+detector_backend (string): opencv, ssd, mtcnn, dlib, retinaface
distance_metric (string): cosine, euclidean, euclidean_l2
enable_face_analysis (boolean): Set this to False to just run face recognition
@@ -774,7 +776,7 @@ def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', ena
if frame_threshold < 1:
raise ValueError("frame_threshold must be greater than the value 1 but you passed "+str(frame_threshold))
-realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis
+realtime.analysis(db_path, model_name, detector_backend, distance_metric, enable_face_analysis
, source = source, time_threshold = time_threshold, frame_threshold = frame_threshold)
def detectFace(img_path, detector_backend = 'opencv', enforce_detection = True):
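
With this change, the real-time pipeline accepts the same detector_backend switch as the other entry points. A minimal usage sketch (the db folder name "my_db" is illustrative):

from deepface import DeepFace

# real-time recognition against a folder of reference images;
# detector_backend now selects the face detector used by the video loop
DeepFace.stream(db_path = "my_db", model_name = "VGG-Face",
                detector_backend = "retinaface",
                time_threshold = 5, frame_threshold = 5)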

View File

@@ -26,14 +26,16 @@ def findThreshold(model_name, distance_metric):
base_threshold = {'cosine': 0.40, 'euclidean': 0.55, 'euclidean_l2': 0.75}
thresholds = {
-'VGG-Face': {'cosine': 0.40, 'euclidean': 0.55, 'euclidean_l2': 0.75},
+'VGG-Face': {'cosine': 0.40, 'euclidean': 0.60, 'euclidean_l2': 0.86},
+'Facenet': {'cosine': 0.40, 'euclidean': 10, 'euclidean_l2': 0.80},
+'Facenet512': {'cosine': 0.30, 'euclidean': 23.56, 'euclidean_l2': 1.04},
+'ArcFace': {'cosine': 0.68, 'euclidean': 4.15, 'euclidean_l2': 1.13},
+'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.4},
'OpenFace': {'cosine': 0.10, 'euclidean': 0.55, 'euclidean_l2': 0.55},
-'Facenet': {'cosine': 0.40, 'euclidean': 10, 'euclidean_l2': 0.80},
-'Facenet512': {'cosine': 0.3088582207770799, 'euclidean': 23.564685968740186, 'euclidean_l2': 1.0461709266148662},
'DeepFace': {'cosine': 0.23, 'euclidean': 64, 'euclidean_l2': 0.64},
-'DeepID': {'cosine': 0.015, 'euclidean': 45, 'euclidean_l2': 0.17},
-'Dlib': {'cosine': 0.07, 'euclidean': 0.6, 'euclidean_l2': 0.6},
-'ArcFace': {'cosine': 0.6871912959056619, 'euclidean': 4.1591468986978075, 'euclidean_l2': 1.1315718048269017}
+'DeepID': {'cosine': 0.015, 'euclidean': 45, 'euclidean_l2': 0.17}
}
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
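
Both .get calls on the last line carry fallbacks: an unknown model resolves to base_threshold, and an unknown metric resolves to 0.4. A standalone sketch of the same lookup (dict trimmed to one entry; the unknown names are invented):

base_threshold = {'cosine': 0.40, 'euclidean': 0.55, 'euclidean_l2': 0.75}
thresholds = {'VGG-Face': {'cosine': 0.40, 'euclidean': 0.60, 'euclidean_l2': 0.86}}

def find_threshold(model_name, distance_metric):
    # missing model -> base_threshold dict; missing metric -> 0.4
    return thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)

print(find_threshold('VGG-Face', 'euclidean_l2'))   # 0.86
print(find_threshold('SomeNewModel', 'cosine'))     # 0.40, via base_threshold
print(find_threshold('VGG-Face', 'manhattan'))      # 0.4, default metric fallback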

View File

@@ -123,19 +123,16 @@ def preprocess_face(img, target_size=(224, 224), grayscale = False, enforce_dete
#post-processing
if grayscale == True:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#---------------------------------------------------
#resize image to expected shape
# img = cv2.resize(img, target_size) #resize causes transformation on base image, adding black pixels to resize will not deform the base image
-# First resize the longer side to the target size
-#factor = target_size[0] / max(img.shape)
factor_0 = target_size[0] / img.shape[0]
factor_1 = target_size[1] / img.shape[1]
factor = min(factor_0, factor_1)
dsize = (int(img.shape[1] * factor), int(img.shape[0] * factor))
img = cv2.resize(img, dsize)
@@ -147,13 +144,13 @@ def preprocess_face(img, target_size=(224, 224), grayscale = False, enforce_dete
img = np.pad(img, ((diff_0 // 2, diff_0 - diff_0 // 2), (diff_1 // 2, diff_1 - diff_1 // 2), (0, 0)), 'constant')
else:
img = np.pad(img, ((diff_0 // 2, diff_0 - diff_0 // 2), (diff_1 // 2, diff_1 - diff_1 // 2)), 'constant')
#double check: if target image is not still the same size with target.
if img.shape[0:2] != target_size:
img = cv2.resize(img, target_size)
#---------------------------------------------------
img_pixels = image.img_to_array(img)
img_pixels = np.expand_dims(img_pixels, axis = 0)
img_pixels /= 255 #normalize input in [0, 1]
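
The resize here scales by the limiting factor and pads the remainder with black pixels, so the detected face is never stretched. A simplified, self-contained sketch of the same logic (3-channel BGR input and a square target assumed):

import cv2
import numpy as np

def resize_with_padding(img, target_size = (224, 224)):
    # scale by the limiting dimension so the aspect ratio is preserved
    factor = min(target_size[0] / img.shape[0], target_size[1] / img.shape[1])
    img = cv2.resize(img, (int(img.shape[1] * factor), int(img.shape[0] * factor)))

    # pad the shorter side with black pixels, keeping the face centered
    diff_0 = target_size[0] - img.shape[0]
    diff_1 = target_size[1] - img.shape[1]
    img = np.pad(img, ((diff_0 // 2, diff_0 - diff_0 // 2),
                       (diff_1 // 2, diff_1 - diff_1 // 2), (0, 0)), 'constant')

    # double check: integer rounding can leave the shape off by a pixel
    if img.shape[0:2] != target_size:
        img = cv2.resize(img, target_size)
    return img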

View File

@@ -12,10 +12,16 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from deepface import DeepFace
from deepface.extendedmodels import Age
from deepface.commons import functions, realtime, distance as dst
-from deepface.detectors import OpenCvWrapper
+from deepface.detectors import FaceDetector
-def analysis(db_path, model_name, distance_metric, enable_face_analysis = True
-, source = 0, time_threshold = 5, frame_threshold = 5):
+def analysis(db_path, model_name = 'VGG-Face', detector_backend = 'opencv', distance_metric = 'cosine', enable_face_analysis = True, source = 0, time_threshold = 5, frame_threshold = 5):
+#------------------------
+face_detector = FaceDetector.build_model(detector_backend)
+print("Detector backend is ", detector_backend)
+#------------------------
input_shape = (224, 224); input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
@@ -45,8 +51,7 @@ def analysis(db_path, model_name, distance_metric, enable_face_analysis = True
#------------------------
input_shape = functions.find_input_shape(model)
-input_shape_x = input_shape[0]
-input_shape_y = input_shape[1]
+input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
#tuned thresholds for model and metric pair
threshold = dst.findThreshold(model_name, distance_metric)
@@ -77,15 +82,21 @@ def analysis(db_path, model_name, distance_metric, enable_face_analysis = True
tic = time.time()
#-----------------------
pbar = tqdm(range(0, len(employees)), desc='Finding embeddings')
+#TODO: why don't you store those embeddings in a pickle file similar to find function?
embeddings = []
#for employee in employees:
for index in pbar:
employee = employees[index]
pbar.set_description("Finding embedding for %s" % (employee.split("/")[-1]))
embedding = []
-img = functions.preprocess_face(img = employee, target_size = (input_shape_y, input_shape_x), enforce_detection = False, detector_backend = 'opencv')
+#preprocess_face returns single face. this is expected for source images in db.
+img = functions.preprocess_face(img = employee, target_size = (input_shape_y, input_shape_x), enforce_detection = False, detector_backend = detector_backend)
img_representation = model.predict(img)[0,:]
embedding.append(employee)
@@ -105,12 +116,6 @@ def analysis(db_path, model_name, distance_metric, enable_face_analysis = True
#-----------------------
-opencv_path = OpenCvWrapper.get_opencv_path()
-face_detector_path = opencv_path+"haarcascade_frontalface_default.xml"
-face_cascade = cv2.CascadeClassifier(face_detector_path)
-#-----------------------
freeze = False
face_detected = False
face_included_frames = 0 #freeze screen if face detected sequentially 5 frames
@@ -129,12 +134,13 @@ def analysis(db_path, model_name, distance_metric, enable_face_analysis = True
#cv2.setWindowProperty('img', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
raw_img = img.copy()
-resolution = img.shape
-resolution_x = img.shape[1]; resolution_y = img.shape[0]
+resolution = img.shape; resolution_x = img.shape[1]; resolution_y = img.shape[0]
if freeze == False:
-faces = face_cascade.detectMultiScale(img, 1.3, 5)
+#faces = face_cascade.detectMultiScale(img, 1.3, 5)
+#faces stores list of detected_face and region pair
+faces = FaceDetector.detect_faces(face_detector, detector_backend, img, align = False)
if len(faces) == 0:
face_included_frames = 0
@@ -143,7 +149,7 @@ def analysis(db_path, model_name, distance_metric, enable_face_analysis = True
detected_faces = []
face_index = 0
-for (x,y,w,h) in faces:
+for face, (x, y, w, h) in faces:
if w > 130: #discard small detected faces
face_detected = True
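
After this change the loop consumes (face, region) pairs from FaceDetector.detect_faces instead of raw haar-cascade rectangles. A trimmed sketch of the new flow (window handling simplified; the webcam source and the w > 130 size filter are taken from the diff):

import cv2
from deepface.detectors import FaceDetector

detector_backend = 'opencv'
face_detector = FaceDetector.build_model(detector_backend)

cap = cv2.VideoCapture(0)  # default webcam source
while True:
    ret, img = cap.read()
    if not ret:
        break
    # each entry pairs the cropped face with its [x, y, w, h] region
    faces = FaceDetector.detect_faces(face_detector, detector_backend, img, align = False)
    for face, (x, y, w, h) in faces:
        if w > 130:  # discard small detected faces, as in the diff
            cv2.rectangle(img, (x, y), (x + w, y + h), (67, 67, 67), 1)
    cv2.imshow('img', img)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cap.release()
cv2.destroyAllWindows()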

View File

@@ -36,6 +36,8 @@ def detect_face(detector, img, align = True):
import dlib #this requirement is not a must that's why imported here
+resp = []
home = str(Path.home())
sp = detector["sp"]
@@ -53,10 +55,12 @@ def detect_face(detector, img, align = True):
top = d.top(); bottom = d.bottom()
detected_face = img[top:bottom, left:right]
img_region = [left, top, right - left, bottom - top]
-break #get the first one
-if align:
-img_shape = sp(img, detections[0])
-detected_face = dlib.get_face_chip(img, img_shape, size = detected_face.shape[0])
+if align:
+img_shape = sp(img, detections[idx])
+detected_face = dlib.get_face_chip(img, img_shape, size = detected_face.shape[0])
-return detected_face, img_region
+resp.append((detected_face, img_region))
+return resp

View File

@@ -33,6 +33,18 @@ def build_model(detector_backend):
def detect_face(face_detector, detector_backend, img, align = True):
+obj = detect_faces(face_detector, detector_backend, img, align)
+if len(obj) > 0:
+face, region = obj[0] #discard multiple faces
+else: #len(obj) == 0
+face = None
+region = [0, 0, img.shape[0], img.shape[1]]
+return face, region
+def detect_faces(face_detector, detector_backend, img, align = True):
backends = {
'opencv': OpenCvWrapper.detect_face,
'ssd': SsdWrapper.detect_face,
@@ -44,12 +56,13 @@ def detect_face(face_detector, detector_backend, img, align = True):
detect_face = backends.get(detector_backend)
if detect_face:
-face, region = detect_face(face_detector, img, align)
+obj = detect_face(face_detector, img, align)
+#obj stores list of detected_face and region pair
+return obj
else:
raise ValueError("invalid detector_backend passed - " + detector_backend)
-return face, region
def alignment_procedure(img, left_eye, right_eye):
#this function aligns given face in img based on left and right eye coordinates
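
detect_face is now a thin wrapper that keeps the old single-face contract on top of the new list-returning detect_faces. A usage sketch contrasting the two (the image file "crowd.jpg" is hypothetical):

import cv2
from deepface.detectors import FaceDetector

img = cv2.imread("crowd.jpg")  # hypothetical multi-face image
detector = FaceDetector.build_model("mtcnn")

# new API: every detected face paired with its region
for face, region in FaceDetector.detect_faces(detector, "mtcnn", img, align = True):
    print(region)  # [x, y, w, h]

# old-style API: first face only; face is None if nothing was found
face, region = FaceDetector.detect_face(detector, "mtcnn", img, align = True)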

View File

@@ -8,6 +8,8 @@ def build_model():
def detect_face(face_detector, img, align = True):
+resp = []
detected_face = None
img_region = [0, 0, img.shape[0], img.shape[1]]
@@ -15,16 +17,18 @@ def detect_face(face_detector, img, align = True):
detections = face_detector.detect_faces(img_rgb)
if len(detections) > 0:
-detection = detections[0]
-x, y, w, h = detection["box"]
-detected_face = img[int(y):int(y+h), int(x):int(x+w)]
-img_region = [x, y, w, h]
-keypoints = detection["keypoints"]
-left_eye = keypoints["left_eye"]
-right_eye = keypoints["right_eye"]
+for detection in detections:
+x, y, w, h = detection["box"]
+detected_face = img[int(y):int(y+h), int(x):int(x+w)]
+img_region = [x, y, w, h]
-if align:
-detected_face = FaceDetector.alignment_procedure(detected_face, left_eye, right_eye)
+if align:
+keypoints = detection["keypoints"]
+left_eye = keypoints["left_eye"]
+right_eye = keypoints["right_eye"]
+detected_face = FaceDetector.alignment_procedure(detected_face, left_eye, right_eye)
-return detected_face, img_region
+resp.append((detected_face, img_region))
+return resp

View File

@@ -37,6 +37,8 @@ def build_cascade(model_name = 'haarcascade'):
def detect_face(detector, img, align = True):
+resp = []
detected_face = None
img_region = [0, 0, img.shape[0], img.shape[1]]
@@ -48,14 +50,18 @@ def detect_face(detector, img, align = True):
pass
if len(faces) > 0:
-x,y,w,h = faces[0] #focus on the 1st face found in the image
-detected_face = img[int(y):int(y+h), int(x):int(x+w)]
-if align:
-detected_face = align_face(detector["eye_detector"], detected_face)
-img_region = [x, y, w, h]
+for x,y,w,h in faces:
+detected_face = img[int(y):int(y+h), int(x):int(x+w)]
-return detected_face, img_region
+if align:
+detected_face = align_face(detector["eye_detector"], detected_face)
+img_region = [x, y, w, h]
+resp.append((detected_face, img_region))
+return resp
def align_face(eye_detector, img):

View File

@@ -1,4 +1,4 @@
-#from retinaface import RetinaFace
+#from retinaface import RetinaFace #this is not a must dependency
import cv2
def build_model():
import cv2
def build_model():
@@ -7,12 +7,19 @@ def build_model():
return face_detector
def detect_face(face_detector, img, align = True):
from retinaface import RetinaFace
+from retinaface.commons import postprocess
+#---------------------------------
+resp = []
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #retinaface expects RGB but OpenCV read BGR
+"""
face = None
-img_region = [0, 0, img.shape[0], img.shape[1]]
+img_region = [0, 0, img.shape[0], img.shape[1]] #Really?
faces = RetinaFace.extract_faces(img_rgb, model = face_detector, align = align)
@@ -20,3 +27,36 @@ def detect_face(face_detector, img, align = True):
face = faces[0][:, :, ::-1]
return face, img_region
+"""
+#--------------------------
+obj = RetinaFace.detect_faces(img_rgb, model = face_detector, threshold = 0.9)
+if type(obj) == dict:
+for key in obj:
+identity = obj[key]
+facial_area = identity["facial_area"]
+y = facial_area[1]
+h = facial_area[3] - y
+x = facial_area[0]
+w = facial_area[2] - x
+img_region = [x, y, w, h]
+#detected_face = img[int(y):int(y+h), int(x):int(x+w)] #opencv
+detected_face = img[facial_area[1]: facial_area[3], facial_area[0]: facial_area[2]]
+if align:
+landmarks = identity["landmarks"]
+left_eye = landmarks["left_eye"]
+right_eye = landmarks["right_eye"]
+nose = landmarks["nose"]
+#mouth_right = landmarks["mouth_right"]
+#mouth_left = landmarks["mouth_left"]
+detected_face = postprocess.alignment_procedure(detected_face, right_eye, left_eye, nose)
+resp.append((detected_face, img_region))
+return resp
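
RetinaFace reports facial_area as [x1, y1, x2, y2] corner coordinates, which the loop above converts to the [x, y, w, h] convention used by the other wrappers. A worked check of that conversion with invented coordinates:

facial_area = [110, 60, 190, 170]   # [x1, y1, x2, y2] from RetinaFace

x = facial_area[0]                  # 110
y = facial_area[1]                  # 60
w = facial_area[2] - x              # 80
h = facial_area[3] - y              # 110

img_region = [x, y, w, h]           # [110, 60, 80, 110]
# img[y1:y2, x1:x2] and img[y:y+h, x:x+w] select the same pixels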

View File

@@ -47,6 +47,8 @@ def build_model():
def detect_face(detector, img, align = True):
+resp = []
detected_face = None
img_region = [0, 0, img.shape[0], img.shape[1]]
@@ -81,20 +83,19 @@ def detect_face(detector, img, align = True):
if detections_df.shape[0] > 0:
+#TODO: sort detections_df
+for index, instance in detections_df.iterrows():
-#get the first face in the image
-instance = detections_df.iloc[0]
-left = instance["left"]
-right = instance["right"]
-bottom = instance["bottom"]
-top = instance["top"]
+left = instance["left"]
+right = instance["right"]
+bottom = instance["bottom"]
+top = instance["top"]
-detected_face = base_img[int(top*aspect_ratio_y):int(bottom*aspect_ratio_y), int(left*aspect_ratio_x):int(right*aspect_ratio_x)]
-img_region = [int(left*aspect_ratio_x), int(top*aspect_ratio_y), int(right*aspect_ratio_x) - int(left*aspect_ratio_x), int(bottom*aspect_ratio_y) - int(top*aspect_ratio_y)]
+detected_face = base_img[int(top*aspect_ratio_y):int(bottom*aspect_ratio_y), int(left*aspect_ratio_x):int(right*aspect_ratio_x)]
+img_region = [int(left*aspect_ratio_x), int(top*aspect_ratio_y), int(right*aspect_ratio_x) - int(left*aspect_ratio_x), int(bottom*aspect_ratio_y) - int(top*aspect_ratio_y)]
-if align:
-detected_face = OpenCvWrapper.align_face(detector["eye_detector"], detected_face)
+if align:
+detected_face = OpenCvWrapper.align_face(detector["eye_detector"], detected_face)
+resp.append((detected_face, img_region))
-return detected_face, img_region
+return resp

View File

@@ -1,3 +1,8 @@
from deepface import DeepFace
-DeepFace.stream("dataset")
+DeepFace.stream("dataset") #opencv
+#DeepFace.stream("dataset", detector_backend = 'opencv')
+#DeepFace.stream("dataset", detector_backend = 'ssd')
+#DeepFace.stream("dataset", detector_backend = 'mtcnn')
+#DeepFace.stream("dataset", detector_backend = 'dlib')
+#DeepFace.stream("dataset", detector_backend = 'retinaface')

View File

@@ -170,7 +170,6 @@ dataset = [
['dataset/img5.jpg', 'dataset/img6.jpg', True],
['dataset/img6.jpg', 'dataset/img7.jpg', True],
['dataset/img8.jpg', 'dataset/img9.jpg', True],
-['dataset/img1.jpg', 'dataset/img11.jpg', True],
['dataset/img2.jpg', 'dataset/img11.jpg', True],