Merge pull request #1327 from serengil/feat-task-3108-exception-handling-for-loading-weights

load weights is done in a common function
This commit is contained in:
Sefik Ilkin Serengil 2024-08-31 16:56:09 +01:00 committed by GitHub
commit 46fe4a8164
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 61 additions and 13 deletions

View File

@ -8,9 +8,15 @@ import bz2
import gdown import gdown
# project dependencies # project dependencies
from deepface.commons import folder_utils from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
tf_version = package_utils.get_tf_major_version()
if tf_version == 1:
from keras.models import Sequential
else:
from tensorflow.keras.models import Sequential
logger = Logger() logger = Logger()
@ -63,3 +69,24 @@ def download_weights_if_necessary(
logger.info(f"{target_file}.bz2 unzipped") logger.info(f"{target_file}.bz2 unzipped")
return target_file return target_file
def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
"""
Load pre-trained weights for a given model
Args:
model (keras.models.Sequential): pre-built model
weight_file (str): exact path of pre-trained weights
Returns:
model (keras.models.Sequential): pre-built model with
updated weights
"""
try:
model.load_weights(weight_file)
except Exception as err:
raise ValueError(
f"Exception while loading pre-trained weights from {weight_file}."
"Possible reason is broken file during downloading weights."
"You may consider to delete it manually."
) from err
return model

View File

@ -69,7 +69,10 @@ def load_model(
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="age_model_weights.h5", source_url=url file_name="age_model_weights.h5", source_url=url
) )
age_model.load_weights(weight_file)
age_model = weight_utils.load_model_weights(
model=age_model, weight_file=weight_file
)
return age_model return age_model

View File

@ -96,6 +96,8 @@ def load_model(
file_name="facial_expression_model_weights.h5", source_url=url file_name="facial_expression_model_weights.h5", source_url=url
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model return model

View File

@ -72,6 +72,8 @@ def load_model(
file_name="gender_model_weights.h5", source_url=url file_name="gender_model_weights.h5", source_url=url
) )
gender_model.load_weights(weight_file) gender_model = weight_utils.load_model_weights(
model=gender_model, weight_file=weight_file
)
return gender_model return gender_model

View File

@ -69,6 +69,8 @@ def load_model(
file_name="race_model_single_batch.h5", source_url=url file_name="race_model_single_batch.h5", source_url=url
) )
race_model.load_weights(weight_file) race_model = weight_utils.load_model_weights(
model=race_model, weight_file=weight_file
)
return race_model return race_model

View File

@ -82,7 +82,7 @@ def load_model(
file_name="arcface_weights.h5", source_url=url file_name="arcface_weights.h5", source_url=url
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
# --------------------------------------- # ---------------------------------------
return model return model

View File

@ -89,6 +89,8 @@ def load_model(
file_name="deepid_keras_weights.h5", source_url=url file_name="deepid_keras_weights.h5", source_url=url
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model return model

View File

@ -1668,7 +1668,9 @@ def load_facenet128d_model(
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet_weights.h5", source_url=url file_name="facenet_weights.h5", source_url=url
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model return model
@ -1687,6 +1689,8 @@ def load_facenet512d_model(
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet512_weights.h5", source_url=url file_name="facenet512_weights.h5", source_url=url
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model return model

View File

@ -86,7 +86,7 @@ def load_model(
file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip" file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip"
) )
base_model.load_weights(weight_file) base_model = weight_utils.load_model_weights(model=base_model, weight_file=weight_file)
# drop F8 and D0. F7 is the representation layer. # drop F8 and D0. F7 is the representation layer.
deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output) deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output)

View File

@ -74,7 +74,9 @@ def load_model():
file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model return model

View File

@ -385,7 +385,9 @@ def load_model(
file_name="openface_weights.h5", source_url=url file_name="openface_weights.h5", source_url=url
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
# ----------------------------------- # -----------------------------------

View File

@ -140,7 +140,9 @@ def load_model(
file_name="vgg_face_weights.h5", source_url=url file_name="vgg_face_weights.h5", source_url=url
) )
model.load_weights(weight_file) model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
# 2622d dimensional model # 2622d dimensional model
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output) # vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)