diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index de2f38a..35e4f5b 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -3,6 +3,7 @@ import warnings warnings.filterwarnings("ignore") import time import os +from os import path import numpy as np import pandas as pd from tqdm import tqdm @@ -11,6 +12,7 @@ import cv2 from keras import backend as K import keras import tensorflow as tf +import pickle #from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace #from extendedmodels import Age, Gender, Race, Emotion @@ -327,7 +329,137 @@ def detectFace(img_path): img = functions.detectFace(img_path)[0] #detectFace returns (1, 224, 224, 3) return img[:, :, ::-1] #bgr to rgb - +def find(img_path, db_path + , model_name ='VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True): + + tic = time.time() + + if type(img_path) == list: + bulkProcess = True + img_paths = img_path.copy() + else: + bulkProcess = False + img_paths = [img_path] + + if os.path.isdir(db_path) == True: + + #--------------------------------------- + + if model == None: + if model_name == 'VGG-Face': + print("Using VGG-Face model backend and", distance_metric,"distance.") + model = VGGFace.loadModel() + elif model_name == 'OpenFace': + print("Using OpenFace model backend", distance_metric,"distance.") + model = OpenFace.loadModel() + elif model_name == 'Facenet': + print("Using Facenet model backend", distance_metric,"distance.") + model = Facenet.loadModel() + elif model_name == 'DeepFace': + print("Using FB DeepFace model backend", distance_metric,"distance.") + model = FbDeepFace.loadModel() + else: + raise ValueError("Invalid model_name passed - ", model_name) + else: #model != None + print("Already built model is passed") + + input_shape = model.layers[0].input_shape[1:3] + threshold = functions.findThreshold(model_name, distance_metric) + + #--------------------------------------- + + file_name = "representations_%s.pkl" % (model_name) + file_name = file_name.replace("-", "_").lower() + + if path.exists(db_path+"/"+file_name): + + print("WARNING: Representations for images in ",db_path," folder were previously stored in ", file_name, ". If you added new instances after this file creation, then please delete this file and call find function again. It will create it again.") + + f = open(db_path+'/'+file_name, 'rb') + representations = pickle.load(f) + + print("There are ", len(representations)," representations found in ",file_name) + + else: + employees = [] + + for r, d, f in os.walk(db_path): # r=root, d=directories, f = files + for file in f: + if ('.jpg' in file): + exact_path = r + "/" + file + employees.append(exact_path) + + if len(employees) == 0: + raise ValueError("There is no image in ", db_path," folder!") + + #------------------------ + #find representations for db images + + representations = [] + + pbar = tqdm(range(0,len(employees)), desc='Finding representations') + + #for employee in employees: + for index in pbar: + employee = employees[index] + img = functions.detectFace(employee, input_shape, enforce_detection = enforce_detection) + representation = model.predict(img)[0,:] + + instance = [] + instance.append(employee) + instance.append(representation) + + representations.append(instance) + + f = open(db_path+'/'+file_name, "wb") + pickle.dump(representations, f) + f.close() + + print("Representations stored in ",db_path,"/",file_name," file. Please delete this file when you add new identities in your database.") + + #---------------------------- + #we got representations for database + df = pd.DataFrame(representations, columns = ["identity", "representation"]) + df_base = df.copy() + + resp_obj = [] + + global_pbar = tqdm(range(0,len(img_paths)), desc='Analyzing') + for j in global_pbar: + img_path = img_paths[j] + + #find representation for passed image + img = functions.detectFace(img_path, input_shape, enforce_detection = enforce_detection) + target_representation = model.predict(img)[0,:] + + distances = [] + for index, instance in df.iterrows(): + source_representation = instance["representation"] + distance = dst.findCosineDistance(source_representation, target_representation) + distances.append(distance) + + df["distance"] = distances + df = df.drop(columns = ["representation"]) + df = df[df.distance <= threshold] + + df = df.sort_values(by = ["distance"], ascending=True).reset_index(drop=True) + resp_obj.append(df) + df = df_base.copy() #restore df for the next iteration + + toc = time.time() + + print("find function lasts ",toc-tic," seconds") + + if len(resp_obj) == 1: + return resp_obj[0] + + return resp_obj + + else: + raise ValueError("Passed db_path does not exist!") + + return None + def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True): realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis)