diff --git a/.gitignore b/.gitignore
index 02f9bbe..63ebe2c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,5 +8,6 @@ Pipfile.lock
.idea/
deepface.egg-info/
tests/dataset/*.pkl
-tests/sandbox.ipynb
+tests/*.ipynb
+tests/*.csv
*.pyc
diff --git a/README.md b/README.md
index b136e6b..a4549c3 100644
--- a/README.md
+++ b/README.md
@@ -90,15 +90,15 @@ Face recognition models basically represent facial images as multi-dimensional v
embedding_objs = DeepFace.represent(img_path = "img.jpg")
```
-This function returns an array as embedding. The size of the embedding array would be different based on the model name. For instance, VGG-Face is the default model and it represents facial images as 2622 dimensional vectors.
+This function returns an array as the embedding. The size of the embedding array depends on the model name. For instance, VGG-Face is the default model, and it represents facial images as 4096-dimensional vectors.
```python
embedding = embedding_objs[0]["embedding"]
assert isinstance(embedding, list)
-assert model_name = "VGG-Face" and len(embedding) == 2622
+assert model_name == "VGG-Face" and len(embedding) == 4096
```
-Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introduction-to-face-recognition-in-deep-learning/) with 2622 slots horizontally. Each slot is corresponding to a dimension value in the embedding vector and dimension value is explained in the colorbar on the right. Similar to 2D barcodes, vertical dimension stores no information in the illustration.
+Here, the embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introduction-to-face-recognition-in-deep-learning/) with 4096 slots horizontally. Each slot corresponds to a dimension value in the embedding vector, and the dimension value is explained by the colorbar on the right. Similar to 2D barcodes, the vertical dimension stores no information in the illustration.

diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py
index 7a83587..4a157d3 100644
--- a/deepface/DeepFace.py
+++ b/deepface/DeepFace.py
@@ -616,6 +616,15 @@ def find(
for index, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"]
+ target_dims = len(list(target_representation))
+ source_dims = len(list(source_representation))
+ if target_dims != source_dims:  # embeddings of different sizes cannot be compared
+ raise ValueError(
+ "Source and target embeddings must have same dimensions but "
+ + f"{target_dims}:{source_dims}. Model structure may change"
+ + f" after pickle created. Delete the {file_name} and re-run."  # f-prefix interpolates file_name
+ )
+
if distance_metric == "cosine":
distance = dst.findCosineDistance(source_representation, target_representation)
elif distance_metric == "euclidean":
@@ -636,6 +645,7 @@ def find(
threshold = dst.findThreshold(model_name, distance_metric)
result_df = result_df.drop(columns=[f"{model_name}_representation"])
+ # pylint: disable=unsubscriptable-object
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
result_df = result_df.sort_values(
by=[f"{model_name}_{distance_metric}"], ascending=True
diff --git a/deepface/basemodels/VGGFace.py b/deepface/basemodels/VGGFace.py
index d02909d..a149425 100644
--- a/deepface/basemodels/VGGFace.py
+++ b/deepface/basemodels/VGGFace.py
@@ -19,7 +19,9 @@ if tf_version == 1:
Flatten,
Dropout,
Activation,
+ Lambda,
)
+ from keras import backend as K
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (
@@ -29,7 +31,9 @@ else:
Flatten,
Dropout,
Activation,
+ Lambda,
)
+ from tensorflow.keras import backend as K
# ---------------------------------------
@@ -98,6 +102,18 @@ def loadModel(
model.load_weights(output)
- vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
+ # 2622-dimensional model (previous descriptor, kept for reference)
+ # vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
+
+ # 4096-dimensional model offers a 6% to 14% increase in accuracy!
+ # - softmax causes underfitting
+ # - added normalization layer to avoid underfitting with euclidean
+ # as described here: https://github.com/serengil/deepface/issues/944
+ # functional API: layers are chained directly off the loaded model, no placeholder needed
+ base_model_output = Flatten()(model.layers[-5].output)
+ base_model_output = Lambda(lambda x: K.l2_normalize(x, axis=1), name="norm_layer")(
+ base_model_output
+ )
+ vgg_face_descriptor = Model(inputs=model.input, outputs=base_model_output)
return vgg_face_descriptor
diff --git a/deepface/commons/distance.py b/deepface/commons/distance.py
index 4ed2247..1744736 100644
--- a/deepface/commons/distance.py
+++ b/deepface/commons/distance.py
@@ -41,7 +41,12 @@ def findThreshold(model_name: str, distance_metric: str) -> float:
base_threshold = {"cosine": 0.40, "euclidean": 0.55, "euclidean_l2": 0.75}
thresholds = {
- "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86},
+ # "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86}, # 2622d
+ "VGG-Face": {
+ "cosine": 0.68,
+ "euclidean": 1.17,
+ "euclidean_l2": 1.17,
+ }, # 4096d - tuned with LFW
"Facenet": {"cosine": 0.40, "euclidean": 10, "euclidean_l2": 0.80},
"Facenet512": {"cosine": 0.30, "euclidean": 23.56, "euclidean_l2": 1.04},
"ArcFace": {"cosine": 0.68, "euclidean": 4.15, "euclidean_l2": 1.13},
diff --git a/tests/test_enforce_detection.py b/tests/test_enforce_detection.py
index 7fa281d..74c4704 100644
--- a/tests/test_enforce_detection.py
+++ b/tests/test_enforce_detection.py
@@ -33,7 +33,7 @@ def test_disabled_enforce_detection_for_non_facial_input_on_represent():
assert "w" in objs[0]["facial_area"].keys()
assert "h" in objs[0]["facial_area"].keys()
assert isinstance(objs[0]["embedding"], list)
- assert len(objs[0]["embedding"]) == 2622 # embedding of VGG-Face
+ assert len(objs[0]["embedding"]) == 4096 # embedding of VGG-Face
logger.info("✅ disabled enforce detection with non facial input test for represent tests done")
diff --git a/tests/test_find.py b/tests/test_find.py
index aefe98e..423567b 100644
--- a/tests/test_find.py
+++ b/tests/test_find.py
@@ -7,19 +7,42 @@ logger = Logger("tests/test_find.py")
def test_find_with_exact_path():
- dfs = DeepFace.find(img_path="dataset/img1.jpg", db_path="dataset", silent=True)
+ img_path = "dataset/img1.jpg"
+ dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True)
+ assert len(dfs) > 0
for df in dfs:
assert isinstance(df, pd.DataFrame)
+
+ # one is img1.jpg itself
+ identity_df = df[df["identity"] == img_path]
+ assert identity_df.shape[0] > 0
+
+ # validate reproducibility: self-distance must be zero up to float rounding
+ assert abs(identity_df["VGG-Face_cosine"].values[0]) < 1e-6
+
+ df = df[df["identity"] != img_path]
logger.debug(df.head())
assert df.shape[0] > 0
logger.info("✅ test find for exact path done")
def test_find_with_array_input():
- img1 = cv2.imread("dataset/img1.jpg")
+ img_path = "dataset/img1.jpg"
+ img1 = cv2.imread(img_path)
dfs = DeepFace.find(img1, db_path="dataset", silent=True)
-
+ assert len(dfs) > 0
for df in dfs:
+ assert isinstance(df, pd.DataFrame)
+
+ # one is img1.jpg itself
+ identity_df = df[df["identity"] == img_path]
+ assert identity_df.shape[0] > 0
+
+ # validate reproducibility: self-distance must be zero up to float rounding
+ assert abs(identity_df["VGG-Face_cosine"].values[0]) < 1e-6
+
+
+ df = df[df["identity"] != img_path]
logger.debug(df.head())
assert df.shape[0] > 0
diff --git a/tests/test_represent.py b/tests/test_represent.py
index 2dd68ea..4b45594 100644
--- a/tests/test_represent.py
+++ b/tests/test_represent.py
@@ -10,7 +10,7 @@ def test_standard_represent():
for embedding_obj in embedding_objs:
embedding = embedding_obj["embedding"]
logger.debug(f"Function returned {len(embedding)} dimensional vector")
- assert len(embedding) == 2622
+ assert len(embedding) == 4096
logger.info("✅ test standard represent function done")