diff --git a/tests/face-recognition-how.py b/tests/face-recognition-how.py index 01988b1..22e48ec 100644 --- a/tests/face-recognition-how.py +++ b/tests/face-recognition-how.py @@ -1,30 +1,28 @@ import matplotlib.pyplot as plt import numpy as np -from deepface.basemodels import VGGFace +from deepface import DeepFace from deepface.commons import functions # ---------------------------------------------- # build face recognition model -model = VGGFace.loadModel() +model_name = "VGG-Face" -try: - input_shape = model.layers[0].input_shape[1:3] -except: # issue 470 - input_shape = model.layers[0].input_shape[0][1:3] +model = DeepFace.build_model(model_name=model_name) -print("model input shape: ", model.layers[0].input_shape[1:]) -print("model output shape: ", model.layers[-1].input_shape[-1]) +target_size = functions.find_target_size(model_name) + +print(f"target_size: {target_size}") # ---------------------------------------------- # load images and find embeddings -# img1 = functions.detectFace("dataset/img1.jpg", input_shape) -img1 = functions.preprocess_face("dataset/img1.jpg", input_shape) +img1 = DeepFace.extract_faces(img_path="dataset/img1.jpg", target_size=target_size)[0]["face"] +img1 = np.expand_dims(img1, axis=0) # to (1, 224, 224, 3) img1_representation = model.predict(img1)[0, :] -# img2 = functions.detectFace("dataset/img3.jpg", input_shape) -img2 = functions.preprocess_face("dataset/img3.jpg", input_shape) +img2 = DeepFace.extract_faces(img_path="dataset/img3.jpg", target_size=target_size)[0]["face"] +img2 = np.expand_dims(img2, axis=0) img2_representation = model.predict(img2)[0, :] # ---------------------------------------------- @@ -58,7 +56,7 @@ distance_graph = np.array(distance_graph) fig = plt.figure() ax1 = fig.add_subplot(3, 2, 1) -plt.imshow(img1[0][:, :, ::-1]) +plt.imshow(img1[0]) plt.axis("off") ax2 = fig.add_subplot(3, 2, 2) @@ -66,7 +64,7 @@ im = plt.imshow(img1_graph, interpolation="nearest", cmap=plt.cm.ocean) plt.colorbar() ax3 = fig.add_subplot(3, 2, 3) -plt.imshow(img2[0][:, :, ::-1]) +plt.imshow(img2[0]) plt.axis("off") ax4 = fig.add_subplot(3, 2, 4)