diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index a0d434b..c5620c3 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -219,19 +219,14 @@ def verify(img1_path, img2_path = '', model_name = 'VGG-Face', distance_metric = #face recognition models have different size of inputs #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. - if model_name == 'Dlib': #this is not a regular keras model - input_shape = (150, 150, 3) + input_shape = model.layers[0].input_shape - else: #keras based models - input_shape = model.layers[0].input_shape - - if type(input_shape) == list: - input_shape = input_shape[0][1:3] - else: - input_shape = input_shape[1:3] - - input_shape_x = input_shape[0] - input_shape_y = input_shape[1] + if type(input_shape) == list: + input_shape = input_shape[0][1:3] + else: + input_shape = input_shape[1:3] + + input_shape_x = input_shape[0]; input_shape_y = input_shape[1] #------------------------------ @@ -590,20 +585,15 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine', employee = employees[index] if model_name != 'Ensemble': - - if model_name == 'Dlib': #non-keras model - input_shape = (150, 150, 3) - else: - #input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. - - input_shape = model.layers[0].input_shape - - if type(input_shape) == list: - input_shape = input_shape[0][1:3] - else: - input_shape = input_shape[1:3] - #--------------------- + #input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. + + input_shape = model.layers[0].input_shape + + if type(input_shape) == list: + input_shape = input_shape[0][1:3] + else: + input_shape = input_shape[1:3] input_shape_x = input_shape[0]; input_shape_y = input_shape[1] @@ -754,22 +744,19 @@ def find(img_path, db_path, model_name ='VGG-Face', distance_metric = 'cosine', if model_name != 'Ensemble': - if model_name == 'Dlib': #non-keras model - input_shape = (150, 150, 3) - else: - #input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. - - input_shape = model.layers[0].input_shape - - if type(input_shape) == list: - input_shape = input_shape[0][1:3] - else: - input_shape = input_shape[1:3] + #input_shape = model.layers[0].input_shape[1:3] #my environment returns (None, 224, 224, 3) but some people mentioned that they got [(None, 224, 224, 3)]. I think this is because of version issue. - #------------------------ + input_shape = model.layers[0].input_shape + + if type(input_shape) == list: + input_shape = input_shape[0][1:3] + else: + input_shape = input_shape[1:3] input_shape_x = input_shape[0]; input_shape_y = input_shape[1] + #------------------------ + img = functions.preprocess_face(img = img_path, target_size = (input_shape_y, input_shape_x), enforce_detection = enforce_detection, detector_backend = detector_backend) target_representation = model.predict(img)[0,:] diff --git a/deepface/basemodels/DlibResNet.py b/deepface/basemodels/DlibResNet.py index 8e1c188..1a859db 100644 --- a/deepface/basemodels/DlibResNet.py +++ b/deepface/basemodels/DlibResNet.py @@ -9,6 +9,10 @@ from pathlib import Path class DlibResNet: def __init__(self): + + self.layers = [DlibMetaData()] + + #--------------------- home = str(Path.home()) weight_file = home+'/.deepface/weights/dlib_face_recognition_resnet_model_v1.dat' @@ -59,4 +63,8 @@ class DlibResNet: img_representation = np.array(img_representation) img_representation = np.expand_dims(img_representation, axis = 0) - return img_representation \ No newline at end of file + return img_representation + +class DlibMetaData: + def __init__(self): + self.input_shape = [[1, 150, 150, 3]] \ No newline at end of file