find function added

This commit is contained in:
Şefik Serangil 2020-05-25 14:59:51 +03:00
parent 48ed830a3d
commit 89142e6872

View File

@ -3,6 +3,7 @@ import warnings
warnings.filterwarnings("ignore")
import time
import os
from os import path
import numpy as np
import pandas as pd
from tqdm import tqdm
@ -11,6 +12,7 @@ import cv2
from keras import backend as K
import keras
import tensorflow as tf
import pickle
#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
#from extendedmodels import Age, Gender, Race, Emotion
@ -327,6 +329,136 @@ def detectFace(img_path):
img = functions.detectFace(img_path)[0] #detectFace returns (1, 224, 224, 3)
return img[:, :, ::-1] #bgr to rgb
def find(img_path, db_path
, model_name ='VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True):
tic = time.time()
if type(img_path) == list:
bulkProcess = True
img_paths = img_path.copy()
else:
bulkProcess = False
img_paths = [img_path]
if os.path.isdir(db_path) == True:
#---------------------------------------
if model == None:
if model_name == 'VGG-Face':
print("Using VGG-Face model backend and", distance_metric,"distance.")
model = VGGFace.loadModel()
elif model_name == 'OpenFace':
print("Using OpenFace model backend", distance_metric,"distance.")
model = OpenFace.loadModel()
elif model_name == 'Facenet':
print("Using Facenet model backend", distance_metric,"distance.")
model = Facenet.loadModel()
elif model_name == 'DeepFace':
print("Using FB DeepFace model backend", distance_metric,"distance.")
model = FbDeepFace.loadModel()
else:
raise ValueError("Invalid model_name passed - ", model_name)
else: #model != None
print("Already built model is passed")
input_shape = model.layers[0].input_shape[1:3]
threshold = functions.findThreshold(model_name, distance_metric)
#---------------------------------------
file_name = "representations_%s.pkl" % (model_name)
file_name = file_name.replace("-", "_").lower()
if path.exists(db_path+"/"+file_name):
print("WARNING: Representations for images in ",db_path," folder were previously stored in ", file_name, ". If you added new instances after this file creation, then please delete this file and call find function again. It will create it again.")
f = open(db_path+'/'+file_name, 'rb')
representations = pickle.load(f)
print("There are ", len(representations)," representations found in ",file_name)
else:
employees = []
for r, d, f in os.walk(db_path): # r=root, d=directories, f = files
for file in f:
if ('.jpg' in file):
exact_path = r + "/" + file
employees.append(exact_path)
if len(employees) == 0:
raise ValueError("There is no image in ", db_path," folder!")
#------------------------
#find representations for db images
representations = []
pbar = tqdm(range(0,len(employees)), desc='Finding representations')
#for employee in employees:
for index in pbar:
employee = employees[index]
img = functions.detectFace(employee, input_shape, enforce_detection = enforce_detection)
representation = model.predict(img)[0,:]
instance = []
instance.append(employee)
instance.append(representation)
representations.append(instance)
f = open(db_path+'/'+file_name, "wb")
pickle.dump(representations, f)
f.close()
print("Representations stored in ",db_path,"/",file_name," file. Please delete this file when you add new identities in your database.")
#----------------------------
#we got representations for database
df = pd.DataFrame(representations, columns = ["identity", "representation"])
df_base = df.copy()
resp_obj = []
global_pbar = tqdm(range(0,len(img_paths)), desc='Analyzing')
for j in global_pbar:
img_path = img_paths[j]
#find representation for passed image
img = functions.detectFace(img_path, input_shape, enforce_detection = enforce_detection)
target_representation = model.predict(img)[0,:]
distances = []
for index, instance in df.iterrows():
source_representation = instance["representation"]
distance = dst.findCosineDistance(source_representation, target_representation)
distances.append(distance)
df["distance"] = distances
df = df.drop(columns = ["representation"])
df = df[df.distance <= threshold]
df = df.sort_values(by = ["distance"], ascending=True).reset_index(drop=True)
resp_obj.append(df)
df = df_base.copy() #restore df for the next iteration
toc = time.time()
print("find function lasts ",toc-tic," seconds")
if len(resp_obj) == 1:
return resp_obj[0]
return resp_obj
else:
raise ValueError("Passed db_path does not exist!")
return None
def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True):
realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis)