find function added

This commit is contained in:
Şefik Serangil 2020-05-25 14:59:51 +03:00
parent 48ed830a3d
commit 89142e6872

View File

@ -3,6 +3,7 @@ import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import time import time
import os import os
from os import path
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
@ -11,6 +12,7 @@ import cv2
from keras import backend as K from keras import backend as K
import keras import keras
import tensorflow as tf import tensorflow as tf
import pickle
#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace #from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
#from extendedmodels import Age, Gender, Race, Emotion #from extendedmodels import Age, Gender, Race, Emotion
@ -327,6 +329,136 @@ def detectFace(img_path):
img = functions.detectFace(img_path)[0] #detectFace returns (1, 224, 224, 3) img = functions.detectFace(img_path)[0] #detectFace returns (1, 224, 224, 3)
return img[:, :, ::-1] #bgr to rgb return img[:, :, ::-1] #bgr to rgb
def find(img_path, db_path
, model_name ='VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True):
tic = time.time()
if type(img_path) == list:
bulkProcess = True
img_paths = img_path.copy()
else:
bulkProcess = False
img_paths = [img_path]
if os.path.isdir(db_path) == True:
#---------------------------------------
if model == None:
if model_name == 'VGG-Face':
print("Using VGG-Face model backend and", distance_metric,"distance.")
model = VGGFace.loadModel()
elif model_name == 'OpenFace':
print("Using OpenFace model backend", distance_metric,"distance.")
model = OpenFace.loadModel()
elif model_name == 'Facenet':
print("Using Facenet model backend", distance_metric,"distance.")
model = Facenet.loadModel()
elif model_name == 'DeepFace':
print("Using FB DeepFace model backend", distance_metric,"distance.")
model = FbDeepFace.loadModel()
else:
raise ValueError("Invalid model_name passed - ", model_name)
else: #model != None
print("Already built model is passed")
input_shape = model.layers[0].input_shape[1:3]
threshold = functions.findThreshold(model_name, distance_metric)
#---------------------------------------
file_name = "representations_%s.pkl" % (model_name)
file_name = file_name.replace("-", "_").lower()
if path.exists(db_path+"/"+file_name):
print("WARNING: Representations for images in ",db_path," folder were previously stored in ", file_name, ". If you added new instances after this file creation, then please delete this file and call find function again. It will create it again.")
f = open(db_path+'/'+file_name, 'rb')
representations = pickle.load(f)
print("There are ", len(representations)," representations found in ",file_name)
else:
employees = []
for r, d, f in os.walk(db_path): # r=root, d=directories, f = files
for file in f:
if ('.jpg' in file):
exact_path = r + "/" + file
employees.append(exact_path)
if len(employees) == 0:
raise ValueError("There is no image in ", db_path," folder!")
#------------------------
#find representations for db images
representations = []
pbar = tqdm(range(0,len(employees)), desc='Finding representations')
#for employee in employees:
for index in pbar:
employee = employees[index]
img = functions.detectFace(employee, input_shape, enforce_detection = enforce_detection)
representation = model.predict(img)[0,:]
instance = []
instance.append(employee)
instance.append(representation)
representations.append(instance)
f = open(db_path+'/'+file_name, "wb")
pickle.dump(representations, f)
f.close()
print("Representations stored in ",db_path,"/",file_name," file. Please delete this file when you add new identities in your database.")
#----------------------------
#we got representations for database
df = pd.DataFrame(representations, columns = ["identity", "representation"])
df_base = df.copy()
resp_obj = []
global_pbar = tqdm(range(0,len(img_paths)), desc='Analyzing')
for j in global_pbar:
img_path = img_paths[j]
#find representation for passed image
img = functions.detectFace(img_path, input_shape, enforce_detection = enforce_detection)
target_representation = model.predict(img)[0,:]
distances = []
for index, instance in df.iterrows():
source_representation = instance["representation"]
distance = dst.findCosineDistance(source_representation, target_representation)
distances.append(distance)
df["distance"] = distances
df = df.drop(columns = ["representation"])
df = df[df.distance <= threshold]
df = df.sort_values(by = ["distance"], ascending=True).reset_index(drop=True)
resp_obj.append(df)
df = df_base.copy() #restore df for the next iteration
toc = time.time()
print("find function lasts ",toc-tic," seconds")
if len(resp_obj) == 1:
return resp_obj[0]
return resp_obj
else:
raise ValueError("Passed db_path does not exist!")
return None
def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True): def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True):
realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis) realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis)