mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 20:15:21 +00:00
find function added
This commit is contained in:
parent
48ed830a3d
commit
89142e6872
@ -3,6 +3,7 @@ import warnings
|
|||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
from os import path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -11,6 +12,7 @@ import cv2
|
|||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
import keras
|
import keras
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
import pickle
|
||||||
|
|
||||||
#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
|
#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
|
||||||
#from extendedmodels import Age, Gender, Race, Emotion
|
#from extendedmodels import Age, Gender, Race, Emotion
|
||||||
@ -327,6 +329,136 @@ def detectFace(img_path):
|
|||||||
img = functions.detectFace(img_path)[0] #detectFace returns (1, 224, 224, 3)
|
img = functions.detectFace(img_path)[0] #detectFace returns (1, 224, 224, 3)
|
||||||
return img[:, :, ::-1] #bgr to rgb
|
return img[:, :, ::-1] #bgr to rgb
|
||||||
|
|
||||||
|
def find(img_path, db_path
|
||||||
|
, model_name ='VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True):
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
if type(img_path) == list:
|
||||||
|
bulkProcess = True
|
||||||
|
img_paths = img_path.copy()
|
||||||
|
else:
|
||||||
|
bulkProcess = False
|
||||||
|
img_paths = [img_path]
|
||||||
|
|
||||||
|
if os.path.isdir(db_path) == True:
|
||||||
|
|
||||||
|
#---------------------------------------
|
||||||
|
|
||||||
|
if model == None:
|
||||||
|
if model_name == 'VGG-Face':
|
||||||
|
print("Using VGG-Face model backend and", distance_metric,"distance.")
|
||||||
|
model = VGGFace.loadModel()
|
||||||
|
elif model_name == 'OpenFace':
|
||||||
|
print("Using OpenFace model backend", distance_metric,"distance.")
|
||||||
|
model = OpenFace.loadModel()
|
||||||
|
elif model_name == 'Facenet':
|
||||||
|
print("Using Facenet model backend", distance_metric,"distance.")
|
||||||
|
model = Facenet.loadModel()
|
||||||
|
elif model_name == 'DeepFace':
|
||||||
|
print("Using FB DeepFace model backend", distance_metric,"distance.")
|
||||||
|
model = FbDeepFace.loadModel()
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid model_name passed - ", model_name)
|
||||||
|
else: #model != None
|
||||||
|
print("Already built model is passed")
|
||||||
|
|
||||||
|
input_shape = model.layers[0].input_shape[1:3]
|
||||||
|
threshold = functions.findThreshold(model_name, distance_metric)
|
||||||
|
|
||||||
|
#---------------------------------------
|
||||||
|
|
||||||
|
file_name = "representations_%s.pkl" % (model_name)
|
||||||
|
file_name = file_name.replace("-", "_").lower()
|
||||||
|
|
||||||
|
if path.exists(db_path+"/"+file_name):
|
||||||
|
|
||||||
|
print("WARNING: Representations for images in ",db_path," folder were previously stored in ", file_name, ". If you added new instances after this file creation, then please delete this file and call find function again. It will create it again.")
|
||||||
|
|
||||||
|
f = open(db_path+'/'+file_name, 'rb')
|
||||||
|
representations = pickle.load(f)
|
||||||
|
|
||||||
|
print("There are ", len(representations)," representations found in ",file_name)
|
||||||
|
|
||||||
|
else:
|
||||||
|
employees = []
|
||||||
|
|
||||||
|
for r, d, f in os.walk(db_path): # r=root, d=directories, f = files
|
||||||
|
for file in f:
|
||||||
|
if ('.jpg' in file):
|
||||||
|
exact_path = r + "/" + file
|
||||||
|
employees.append(exact_path)
|
||||||
|
|
||||||
|
if len(employees) == 0:
|
||||||
|
raise ValueError("There is no image in ", db_path," folder!")
|
||||||
|
|
||||||
|
#------------------------
|
||||||
|
#find representations for db images
|
||||||
|
|
||||||
|
representations = []
|
||||||
|
|
||||||
|
pbar = tqdm(range(0,len(employees)), desc='Finding representations')
|
||||||
|
|
||||||
|
#for employee in employees:
|
||||||
|
for index in pbar:
|
||||||
|
employee = employees[index]
|
||||||
|
img = functions.detectFace(employee, input_shape, enforce_detection = enforce_detection)
|
||||||
|
representation = model.predict(img)[0,:]
|
||||||
|
|
||||||
|
instance = []
|
||||||
|
instance.append(employee)
|
||||||
|
instance.append(representation)
|
||||||
|
|
||||||
|
representations.append(instance)
|
||||||
|
|
||||||
|
f = open(db_path+'/'+file_name, "wb")
|
||||||
|
pickle.dump(representations, f)
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
print("Representations stored in ",db_path,"/",file_name," file. Please delete this file when you add new identities in your database.")
|
||||||
|
|
||||||
|
#----------------------------
|
||||||
|
#we got representations for database
|
||||||
|
df = pd.DataFrame(representations, columns = ["identity", "representation"])
|
||||||
|
df_base = df.copy()
|
||||||
|
|
||||||
|
resp_obj = []
|
||||||
|
|
||||||
|
global_pbar = tqdm(range(0,len(img_paths)), desc='Analyzing')
|
||||||
|
for j in global_pbar:
|
||||||
|
img_path = img_paths[j]
|
||||||
|
|
||||||
|
#find representation for passed image
|
||||||
|
img = functions.detectFace(img_path, input_shape, enforce_detection = enforce_detection)
|
||||||
|
target_representation = model.predict(img)[0,:]
|
||||||
|
|
||||||
|
distances = []
|
||||||
|
for index, instance in df.iterrows():
|
||||||
|
source_representation = instance["representation"]
|
||||||
|
distance = dst.findCosineDistance(source_representation, target_representation)
|
||||||
|
distances.append(distance)
|
||||||
|
|
||||||
|
df["distance"] = distances
|
||||||
|
df = df.drop(columns = ["representation"])
|
||||||
|
df = df[df.distance <= threshold]
|
||||||
|
|
||||||
|
df = df.sort_values(by = ["distance"], ascending=True).reset_index(drop=True)
|
||||||
|
resp_obj.append(df)
|
||||||
|
df = df_base.copy() #restore df for the next iteration
|
||||||
|
|
||||||
|
toc = time.time()
|
||||||
|
|
||||||
|
print("find function lasts ",toc-tic," seconds")
|
||||||
|
|
||||||
|
if len(resp_obj) == 1:
|
||||||
|
return resp_obj[0]
|
||||||
|
|
||||||
|
return resp_obj
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Passed db_path does not exist!")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True):
|
def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True):
|
||||||
realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis)
|
realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user