Merge pull request #948 from serengil/feat-task-0701-vgg-descriptor

VGG-Face descriptor with new structure
This commit is contained in:
Sefik Ilkin Serengil 2024-01-08 17:40:21 +00:00 committed by GitHub
commit 1b40870a8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 66 additions and 11 deletions

3
.gitignore vendored
View File

@ -8,5 +8,6 @@ Pipfile.lock
.idea/
deepface.egg-info/
tests/dataset/*.pkl
tests/sandbox.ipynb
tests/*.ipynb
tests/*.csv
*.pyc

View File

@ -90,15 +90,15 @@ Face recognition models basically represent facial images as multi-dimensional v
embedding_objs = DeepFace.represent(img_path = "img.jpg")
```
This function returns an array as embedding. The size of the embedding array would be different based on the model name. For instance, VGG-Face is the default model and it represents facial images as 2622 dimensional vectors.
This function returns an array as the embedding. The size of the embedding array depends on the model name. For instance, VGG-Face is the default model and it represents facial images as 4096-dimensional vectors.
```python
embedding = embedding_objs[0]["embedding"]
assert isinstance(embedding, list)
assert model_name = "VGG-Face" and len(embedding) == 2622
assert model_name == "VGG-Face" and len(embedding) == 4096
```
Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introduction-to-face-recognition-in-deep-learning/) with 2622 slots horizontally. Each slot is corresponding to a dimension value in the embedding vector and dimension value is explained in the colorbar on the right. Similar to 2D barcodes, vertical dimension stores no information in the illustration.
Here, the embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introduction-to-face-recognition-in-deep-learning/) with 4096 slots horizontally. Each slot corresponds to a dimension value in the embedding vector, and the dimension value is explained in the colorbar on the right. Similar to 2D barcodes, the vertical dimension stores no information in the illustration.
<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/embedding.jpg" width="95%" height="95%"></p>

View File

@ -616,6 +616,15 @@ def find(
for index, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"]
target_dims = len(list(target_representation))
source_dims = len(list(source_representation))
if target_dims != source_dims:
raise ValueError(
"Source and target embeddings must have same dimensions but "
+ f"{target_dims}:{source_dims}. Model structure may change"
+ " after pickle created. Delete the {file_name} and re-run."
)
if distance_metric == "cosine":
distance = dst.findCosineDistance(source_representation, target_representation)
elif distance_metric == "euclidean":
@ -636,6 +645,7 @@ def find(
threshold = dst.findThreshold(model_name, distance_metric)
result_df = result_df.drop(columns=[f"{model_name}_representation"])
# pylint: disable=unsubscriptable-object
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
result_df = result_df.sort_values(
by=[f"{model_name}_{distance_metric}"], ascending=True

View File

@ -19,7 +19,9 @@ if tf_version == 1:
Flatten,
Dropout,
Activation,
Lambda,
)
from keras import backend as K
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (
@ -29,7 +31,9 @@ else:
Flatten,
Dropout,
Activation,
Lambda,
)
from tensorflow.keras import backend as K
# ---------------------------------------
@ -98,6 +102,18 @@ def loadModel(
model.load_weights(output)
vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
# 2622d dimensional model
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
# 4096 dimensional model offers 6% to 14% increase in accuracy!
# - softmax causes underfitting
# - added normalization layer to avoid underfitting with euclidean
# as described here: https://github.com/serengil/deepface/issues/944
base_model_output = Sequential()
base_model_output = Flatten()(model.layers[-5].output)
base_model_output = Lambda(lambda x: K.l2_normalize(x, axis=1), name="norm_layer")(
base_model_output
)
vgg_face_descriptor = Model(inputs=model.input, outputs=base_model_output)
return vgg_face_descriptor

View File

@ -41,7 +41,12 @@ def findThreshold(model_name: str, distance_metric: str) -> float:
base_threshold = {"cosine": 0.40, "euclidean": 0.55, "euclidean_l2": 0.75}
thresholds = {
"VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86},
# "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86}, # 2622d
"VGG-Face": {
"cosine": 0.68,
"euclidean": 1.17,
"euclidean_l2": 1.17,
}, # 4096d - tuned with LFW
"Facenet": {"cosine": 0.40, "euclidean": 10, "euclidean_l2": 0.80},
"Facenet512": {"cosine": 0.30, "euclidean": 23.56, "euclidean_l2": 1.04},
"ArcFace": {"cosine": 0.68, "euclidean": 4.15, "euclidean_l2": 1.13},

View File

@ -33,7 +33,7 @@ def test_disabled_enforce_detection_for_non_facial_input_on_represent():
assert "w" in objs[0]["facial_area"].keys()
assert "h" in objs[0]["facial_area"].keys()
assert isinstance(objs[0]["embedding"], list)
assert len(objs[0]["embedding"]) == 2622 # embedding of VGG-Face
assert len(objs[0]["embedding"]) == 4096 # embedding of VGG-Face
logger.info("✅ disabled enforce detection with non facial input test for represent tests done")

View File

@ -7,19 +7,42 @@ logger = Logger("tests/test_find.py")
def test_find_with_exact_path():
dfs = DeepFace.find(img_path="dataset/img1.jpg", db_path="dataset", silent=True)
img_path = "dataset/img1.jpg"
dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True)
assert len(dfs) > 0
for df in dfs:
assert isinstance(df, pd.DataFrame)
# one is img1.jpg itself
identity_df = df[df["identity"] == img_path]
assert identity_df.shape[0] > 0
# validate reproducibility
assert identity_df["VGG-Face_cosine"].values[0] == 0
df = df[df["identity"] != img_path]
logger.debug(df.head())
assert df.shape[0] > 0
logger.info("✅ test find for exact path done")
def test_find_with_array_input():
img1 = cv2.imread("dataset/img1.jpg")
img_path = "dataset/img1.jpg"
img1 = cv2.imread(img_path)
dfs = DeepFace.find(img1, db_path="dataset", silent=True)
assert len(dfs) > 0
for df in dfs:
assert isinstance(df, pd.DataFrame)
# one is img1.jpg itself
identity_df = df[df["identity"] == img_path]
assert identity_df.shape[0] > 0
# validate reproducibility
assert identity_df["VGG-Face_cosine"].values[0] == 0
df = df[df["identity"] != img_path]
logger.debug(df.head())
assert df.shape[0] > 0

View File

@ -10,7 +10,7 @@ def test_standard_represent():
for embedding_obj in embedding_objs:
embedding = embedding_obj["embedding"]
logger.debug(f"Function returned {len(embedding)} dimensional vector")
assert len(embedding) == 2622
assert len(embedding) == 4096
logger.info("✅ test standard represent function done")