dlib input shape retrieved from class similar to others

This commit is contained in:
serengil 2020-11-30 11:11:16 +03:00
parent cdea61c043
commit edae2a799c
2 changed files with 33 additions and 38 deletions

View File

@ -219,10 +219,6 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric =
#face recognition models have different size of inputs #face recognition models have different size of inputs
#my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
if model_name == 'Dlib': #this is not a regular keras model
input_shape = (150, 150, 3)
else: #keras based models
input_shape = model.layers[0].input_shape input_shape = model.layers[0].input_shape
if type(input_shape) == list: if type(input_shape) == list:
@ -230,8 +226,7 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric =
else: else:
input_shape = input_shape[1:3] input_shape = input_shape[1:3]
input_shape_x = input_shape[0] input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
input_shape_y = input_shape[1]
#------------------------------ #------------------------------
@ -591,9 +586,6 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
if model_name != 'Ensemble': if model_name != 'Ensemble':
if model_name == 'Dlib': #non-keras model
input_shape = (150, 150, 3)
else:
#input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. #input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
input_shape = model.layers[0].input_shape input_shape = model.layers[0].input_shape
@ -603,8 +595,6 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
else: else:
input_shape = input_shape[1:3] input_shape = input_shape[1:3]
#---------------------
input_shape_x = input_shape[0]; input_shape_y = input_shape[1] input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
img = functions.preprocess_face(img = employee, target_size = (input_shape_y, input_shape_x), enforce_detection = enforce_detection, detector_backend = detector_backend) img = functions.preprocess_face(img = employee, target_size = (input_shape_y, input_shape_x), enforce_detection = enforce_detection, detector_backend = detector_backend)
@ -754,9 +744,6 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
if model_name != 'Ensemble': if model_name != 'Ensemble':
if model_name == 'Dlib': #non-keras model
input_shape = (150, 150, 3)
else:
#input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. #input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue.
input_shape = model.layers[0].input_shape input_shape = model.layers[0].input_shape
@ -766,10 +753,10 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine',
else: else:
input_shape = input_shape[1:3] input_shape = input_shape[1:3]
#------------------------
input_shape_x = input_shape[0]; input_shape_y = input_shape[1] input_shape_x = input_shape[0]; input_shape_y = input_shape[1]
#------------------------
img = functions.preprocess_face(img = img_path, target_size = (input_shape_y, input_shape_x), enforce_detection = enforce_detection, detector_backend = detector_backend) img = functions.preprocess_face(img = img_path, target_size = (input_shape_y, input_shape_x), enforce_detection = enforce_detection, detector_backend = detector_backend)
target_representation = model.predict(img)[0,:] target_representation = model.predict(img)[0,:]

View File

@ -10,6 +10,10 @@ class DlibResNet:
def __init__(self): def __init__(self):
self.layers = [DlibMetaData()]
#---------------------
home = str(Path.home()) home = str(Path.home())
weight_file = home+'/.deepface/weights/dlib_face_recognition_resnet_model_v1.dat' weight_file = home+'/.deepface/weights/dlib_face_recognition_resnet_model_v1.dat'
@ -60,3 +64,7 @@ class DlibResNet:
img_representation = np.expand_dims(img_representation, axis = 0) img_representation = np.expand_dims(img_representation, axis = 0)
return img_representation return img_representation
class DlibMetaData:
def __init__(self):
self.input_shape = [[1, 150, 150, 3]]