load weights is done in a common function

This commit is contained in:
Sefik Ilkin Serengil 2024-08-31 16:46:44 +01:00
parent a3088ac903
commit 8b86f36390
12 changed files with 61 additions and 13 deletions

View File

@ -8,9 +8,15 @@ import bz2
import gdown
# project dependencies
from deepface.commons import folder_utils
from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger
tf_version = package_utils.get_tf_major_version()
if tf_version == 1:
from keras.models import Sequential
else:
from tensorflow.keras.models import Sequential
logger = Logger()
@ -63,3 +69,24 @@ def download_weights_if_necessary(
logger.info(f"{target_file}.bz2 unzipped")
return target_file
def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
"""
Load pre-trained weights for a given model
Args:
model (keras.models.Sequential): pre-built model
weight_file (str): exact path of pre-trained weights
Returns:
model (keras.models.Sequential): pre-built model with
updated weights
"""
try:
model.load_weights(weight_file)
except Exception as err:
raise ValueError(
f"Exception while loading pre-trained weights from {weight_file}."
"Possible reason is broken file during downloading weights."
"You may consider to delete it manually."
) from err
return model

View File

@ -69,7 +69,10 @@ def load_model(
weight_file = weight_utils.download_weights_if_necessary(
file_name="age_model_weights.h5", source_url=url
)
age_model.load_weights(weight_file)
age_model = weight_utils.load_model_weights(
model=age_model, weight_file=weight_file
)
return age_model

View File

@ -96,6 +96,8 @@ def load_model(
file_name="facial_expression_model_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model

View File

@ -72,6 +72,8 @@ def load_model(
file_name="gender_model_weights.h5", source_url=url
)
gender_model.load_weights(weight_file)
gender_model = weight_utils.load_model_weights(
model=gender_model, weight_file=weight_file
)
return gender_model

View File

@ -69,6 +69,8 @@ def load_model(
file_name="race_model_single_batch.h5", source_url=url
)
race_model.load_weights(weight_file)
race_model = weight_utils.load_model_weights(
model=race_model, weight_file=weight_file
)
return race_model

View File

@ -82,7 +82,7 @@ def load_model(
file_name="arcface_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
# ---------------------------------------
return model

View File

@ -89,6 +89,8 @@ def load_model(
file_name="deepid_keras_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model

View File

@ -1668,7 +1668,9 @@ def load_facenet128d_model(
weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model
@ -1687,6 +1689,8 @@ def load_facenet512d_model(
weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet512_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model

View File

@ -86,7 +86,7 @@ def load_model(
file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip"
)
base_model.load_weights(weight_file)
base_model = weight_utils.load_model_weights(model=base_model, weight_file=weight_file)
# drop F8 and D0. F7 is the representation layer.
deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output)

View File

@ -74,7 +74,9 @@ def load_model():
file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
return model

View File

@ -385,7 +385,9 @@ def load_model(
file_name="openface_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
# -----------------------------------

View File

@ -140,7 +140,9 @@ def load_model(
file_name="vgg_face_weights.h5", source_url=url
)
model.load_weights(weight_file)
model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
# 2622d dimensional model
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)