mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 20:15:21 +00:00
find function added
This commit is contained in:
parent
48ed830a3d
commit
89142e6872
@ -3,6 +3,7 @@ import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
import time
|
||||
import os
|
||||
from os import path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
@ -11,6 +12,7 @@ import cv2
|
||||
from keras import backend as K
|
||||
import keras
|
||||
import tensorflow as tf
|
||||
import pickle
|
||||
|
||||
#from basemodels import VGGFace, OpenFace, Facenet, FbDeepFace
|
||||
#from extendedmodels import Age, Gender, Race, Emotion
|
||||
@ -327,6 +329,136 @@ def detectFace(img_path):
|
||||
img = functions.detectFace(img_path)[0] #detectFace returns (1, 224, 224, 3)
|
||||
return img[:, :, ::-1] #bgr to rgb
|
||||
|
||||
def find(img_path, db_path
|
||||
, model_name ='VGG-Face', distance_metric = 'cosine', model = None, enforce_detection = True):
|
||||
|
||||
tic = time.time()
|
||||
|
||||
if type(img_path) == list:
|
||||
bulkProcess = True
|
||||
img_paths = img_path.copy()
|
||||
else:
|
||||
bulkProcess = False
|
||||
img_paths = [img_path]
|
||||
|
||||
if os.path.isdir(db_path) == True:
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
if model == None:
|
||||
if model_name == 'VGG-Face':
|
||||
print("Using VGG-Face model backend and", distance_metric,"distance.")
|
||||
model = VGGFace.loadModel()
|
||||
elif model_name == 'OpenFace':
|
||||
print("Using OpenFace model backend", distance_metric,"distance.")
|
||||
model = OpenFace.loadModel()
|
||||
elif model_name == 'Facenet':
|
||||
print("Using Facenet model backend", distance_metric,"distance.")
|
||||
model = Facenet.loadModel()
|
||||
elif model_name == 'DeepFace':
|
||||
print("Using FB DeepFace model backend", distance_metric,"distance.")
|
||||
model = FbDeepFace.loadModel()
|
||||
else:
|
||||
raise ValueError("Invalid model_name passed - ", model_name)
|
||||
else: #model != None
|
||||
print("Already built model is passed")
|
||||
|
||||
input_shape = model.layers[0].input_shape[1:3]
|
||||
threshold = functions.findThreshold(model_name, distance_metric)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
file_name = "representations_%s.pkl" % (model_name)
|
||||
file_name = file_name.replace("-", "_").lower()
|
||||
|
||||
if path.exists(db_path+"/"+file_name):
|
||||
|
||||
print("WARNING: Representations for images in ",db_path," folder were previously stored in ", file_name, ". If you added new instances after this file creation, then please delete this file and call find function again. It will create it again.")
|
||||
|
||||
f = open(db_path+'/'+file_name, 'rb')
|
||||
representations = pickle.load(f)
|
||||
|
||||
print("There are ", len(representations)," representations found in ",file_name)
|
||||
|
||||
else:
|
||||
employees = []
|
||||
|
||||
for r, d, f in os.walk(db_path): # r=root, d=directories, f = files
|
||||
for file in f:
|
||||
if ('.jpg' in file):
|
||||
exact_path = r + "/" + file
|
||||
employees.append(exact_path)
|
||||
|
||||
if len(employees) == 0:
|
||||
raise ValueError("There is no image in ", db_path," folder!")
|
||||
|
||||
#------------------------
|
||||
#find representations for db images
|
||||
|
||||
representations = []
|
||||
|
||||
pbar = tqdm(range(0,len(employees)), desc='Finding representations')
|
||||
|
||||
#for employee in employees:
|
||||
for index in pbar:
|
||||
employee = employees[index]
|
||||
img = functions.detectFace(employee, input_shape, enforce_detection = enforce_detection)
|
||||
representation = model.predict(img)[0,:]
|
||||
|
||||
instance = []
|
||||
instance.append(employee)
|
||||
instance.append(representation)
|
||||
|
||||
representations.append(instance)
|
||||
|
||||
f = open(db_path+'/'+file_name, "wb")
|
||||
pickle.dump(representations, f)
|
||||
f.close()
|
||||
|
||||
print("Representations stored in ",db_path,"/",file_name," file. Please delete this file when you add new identities in your database.")
|
||||
|
||||
#----------------------------
|
||||
#we got representations for database
|
||||
df = pd.DataFrame(representations, columns = ["identity", "representation"])
|
||||
df_base = df.copy()
|
||||
|
||||
resp_obj = []
|
||||
|
||||
global_pbar = tqdm(range(0,len(img_paths)), desc='Analyzing')
|
||||
for j in global_pbar:
|
||||
img_path = img_paths[j]
|
||||
|
||||
#find representation for passed image
|
||||
img = functions.detectFace(img_path, input_shape, enforce_detection = enforce_detection)
|
||||
target_representation = model.predict(img)[0,:]
|
||||
|
||||
distances = []
|
||||
for index, instance in df.iterrows():
|
||||
source_representation = instance["representation"]
|
||||
distance = dst.findCosineDistance(source_representation, target_representation)
|
||||
distances.append(distance)
|
||||
|
||||
df["distance"] = distances
|
||||
df = df.drop(columns = ["representation"])
|
||||
df = df[df.distance <= threshold]
|
||||
|
||||
df = df.sort_values(by = ["distance"], ascending=True).reset_index(drop=True)
|
||||
resp_obj.append(df)
|
||||
df = df_base.copy() #restore df for the next iteration
|
||||
|
||||
toc = time.time()
|
||||
|
||||
print("find function lasts ",toc-tic," seconds")
|
||||
|
||||
if len(resp_obj) == 1:
|
||||
return resp_obj[0]
|
||||
|
||||
return resp_obj
|
||||
|
||||
else:
|
||||
raise ValueError("Passed db_path does not exist!")
|
||||
|
||||
return None
|
||||
|
||||
def stream(db_path = '', model_name ='VGG-Face', distance_metric = 'cosine', enable_face_analysis = True):
|
||||
realtime.analysis(db_path, model_name, distance_metric, enable_face_analysis)
|
||||
|
Loading…
x
Reference in New Issue
Block a user