From 8b86f3639091ffee50c361fe64138d0baac0a6f9 Mon Sep 17 00:00:00 2001 From: Sefik Ilkin Serengil Date: Sat, 31 Aug 2024 16:46:44 +0100 Subject: [PATCH] load weights is done in a common function --- deepface/commons/weight_utils.py | 29 ++++++++++++++++++- deepface/models/demography/Age.py | 5 +++- deepface/models/demography/Emotion.py | 4 ++- deepface/models/demography/Gender.py | 4 ++- deepface/models/demography/Race.py | 4 ++- deepface/models/facial_recognition/ArcFace.py | 2 +- deepface/models/facial_recognition/DeepID.py | 4 ++- deepface/models/facial_recognition/Facenet.py | 8 +++-- .../models/facial_recognition/FbDeepFace.py | 2 +- .../models/facial_recognition/GhostFaceNet.py | 4 ++- .../models/facial_recognition/OpenFace.py | 4 ++- deepface/models/facial_recognition/VGGFace.py | 4 ++- 12 files changed, 61 insertions(+), 13 deletions(-) diff --git a/deepface/commons/weight_utils.py b/deepface/commons/weight_utils.py index 2c2508e..cbac658 100644 --- a/deepface/commons/weight_utils.py +++ b/deepface/commons/weight_utils.py @@ -8,9 +8,15 @@ import bz2 import gdown # project dependencies -from deepface.commons import folder_utils +from deepface.commons import folder_utils, package_utils from deepface.commons.logger import Logger +tf_version = package_utils.get_tf_major_version() +if tf_version == 1: + from keras.models import Sequential +else: + from tensorflow.keras.models import Sequential + logger = Logger() @@ -63,3 +69,24 @@ def download_weights_if_necessary( logger.info(f"{target_file}.bz2 unzipped") return target_file + + +def load_model_weights(model: Sequential, weight_file: str) -> Sequential: + """ + Load pre-trained weights for a given model + Args: + model (keras.models.Sequential): pre-built model + weight_file (str): exact path of pre-trained weights + Returns: + model (keras.models.Sequential): pre-built model with + updated weights + """ + try: + model.load_weights(weight_file) + except Exception as err: + raise ValueError( + f"Exception while loading pre-trained weights from {weight_file}." + "Possible reason is broken file during downloading weights." + "You may consider to delete it manually." + ) from err + return model diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index e7a6a39..29efdf5 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -69,7 +69,10 @@ def load_model( weight_file = weight_utils.download_weights_if_necessary( file_name="age_model_weights.h5", source_url=url ) - age_model.load_weights(weight_file) + + age_model = weight_utils.load_model_weights( + model=age_model, weight_file=weight_file + ) return age_model diff --git a/deepface/models/demography/Emotion.py b/deepface/models/demography/Emotion.py index 7dcb95e..3d1d88f 100644 --- a/deepface/models/demography/Emotion.py +++ b/deepface/models/demography/Emotion.py @@ -96,6 +96,8 @@ def load_model( file_name="facial_expression_model_weights.h5", source_url=url ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights( + model=model, weight_file=weight_file + ) return model diff --git a/deepface/models/demography/Gender.py b/deepface/models/demography/Gender.py index f682434..2f3a142 100644 --- a/deepface/models/demography/Gender.py +++ b/deepface/models/demography/Gender.py @@ -72,6 +72,8 @@ def load_model( file_name="gender_model_weights.h5", source_url=url ) - gender_model.load_weights(weight_file) + gender_model = weight_utils.load_model_weights( + model=gender_model, weight_file=weight_file + ) return gender_model diff --git a/deepface/models/demography/Race.py b/deepface/models/demography/Race.py index aaf9564..a393667 100644 --- a/deepface/models/demography/Race.py +++ b/deepface/models/demography/Race.py @@ -69,6 +69,8 @@ def load_model( file_name="race_model_single_batch.h5", source_url=url ) - race_model.load_weights(weight_file) + race_model = weight_utils.load_model_weights( + model=race_model, weight_file=weight_file + ) return race_model diff --git a/deepface/models/facial_recognition/ArcFace.py b/deepface/models/facial_recognition/ArcFace.py index 6eda686..596192f 100644 --- a/deepface/models/facial_recognition/ArcFace.py +++ b/deepface/models/facial_recognition/ArcFace.py @@ -82,7 +82,7 @@ def load_model( file_name="arcface_weights.h5", source_url=url ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights(model=model, weight_file=weight_file) # --------------------------------------- return model diff --git a/deepface/models/facial_recognition/DeepID.py b/deepface/models/facial_recognition/DeepID.py index e8217d4..ea03b4e 100644 --- a/deepface/models/facial_recognition/DeepID.py +++ b/deepface/models/facial_recognition/DeepID.py @@ -89,6 +89,8 @@ def load_model( file_name="deepid_keras_weights.h5", source_url=url ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights( + model=model, weight_file=weight_file + ) return model diff --git a/deepface/models/facial_recognition/Facenet.py b/deepface/models/facial_recognition/Facenet.py index f1be068..b1ad37c 100644 --- a/deepface/models/facial_recognition/Facenet.py +++ b/deepface/models/facial_recognition/Facenet.py @@ -1668,7 +1668,9 @@ def load_facenet128d_model( weight_file = weight_utils.download_weights_if_necessary( file_name="facenet_weights.h5", source_url=url ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights( + model=model, weight_file=weight_file + ) return model @@ -1687,6 +1689,8 @@ def load_facenet512d_model( weight_file = weight_utils.download_weights_if_necessary( file_name="facenet512_weights.h5", source_url=url ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights( + model=model, weight_file=weight_file + ) return model diff --git a/deepface/models/facial_recognition/FbDeepFace.py b/deepface/models/facial_recognition/FbDeepFace.py index b53b393..fb41d62 100644 --- a/deepface/models/facial_recognition/FbDeepFace.py +++ b/deepface/models/facial_recognition/FbDeepFace.py @@ -86,7 +86,7 @@ def load_model( file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip" ) - base_model.load_weights(weight_file) + base_model = weight_utils.load_model_weights(model=base_model, weight_file=weight_file) # drop F8 and D0. F7 is the representation layer. deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output) diff --git a/deepface/models/facial_recognition/GhostFaceNet.py b/deepface/models/facial_recognition/GhostFaceNet.py index 2bb9623..37bd728 100644 --- a/deepface/models/facial_recognition/GhostFaceNet.py +++ b/deepface/models/facial_recognition/GhostFaceNet.py @@ -74,7 +74,9 @@ def load_model(): file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights( + model=model, weight_file=weight_file + ) return model diff --git a/deepface/models/facial_recognition/OpenFace.py b/deepface/models/facial_recognition/OpenFace.py index 5163c95..c9c1b7a 100644 --- a/deepface/models/facial_recognition/OpenFace.py +++ b/deepface/models/facial_recognition/OpenFace.py @@ -385,7 +385,9 @@ def load_model( file_name="openface_weights.h5", source_url=url ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights( + model=model, weight_file=weight_file + ) # ----------------------------------- diff --git a/deepface/models/facial_recognition/VGGFace.py b/deepface/models/facial_recognition/VGGFace.py index f307ad7..56c8a54 100644 --- a/deepface/models/facial_recognition/VGGFace.py +++ b/deepface/models/facial_recognition/VGGFace.py @@ -140,7 +140,9 @@ def load_model( file_name="vgg_face_weights.h5", source_url=url ) - model.load_weights(weight_file) + model = weight_utils.load_model_weights( + model=model, weight_file=weight_file + ) # 2622d dimensional model # vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)