mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
load weights is done in a common function
This commit is contained in:
parent
a3088ac903
commit
8b86f36390
@ -8,9 +8,15 @@ import bz2
|
||||
import gdown
|
||||
|
||||
# project dependencies
|
||||
from deepface.commons import folder_utils
|
||||
from deepface.commons import folder_utils, package_utils
|
||||
from deepface.commons.logger import Logger
|
||||
|
||||
tf_version = package_utils.get_tf_major_version()
|
||||
if tf_version == 1:
|
||||
from keras.models import Sequential
|
||||
else:
|
||||
from tensorflow.keras.models import Sequential
|
||||
|
||||
logger = Logger()
|
||||
|
||||
|
||||
@ -63,3 +69,24 @@ def download_weights_if_necessary(
|
||||
logger.info(f"{target_file}.bz2 unzipped")
|
||||
|
||||
return target_file
|
||||
|
||||
|
||||
def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
|
||||
"""
|
||||
Load pre-trained weights for a given model
|
||||
Args:
|
||||
model (keras.models.Sequential): pre-built model
|
||||
weight_file (str): exact path of pre-trained weights
|
||||
Returns:
|
||||
model (keras.models.Sequential): pre-built model with
|
||||
updated weights
|
||||
"""
|
||||
try:
|
||||
model.load_weights(weight_file)
|
||||
except Exception as err:
|
||||
raise ValueError(
|
||||
f"Exception while loading pre-trained weights from {weight_file}."
|
||||
"Possible reason is broken file during downloading weights."
|
||||
"You may consider to delete it manually."
|
||||
) from err
|
||||
return model
|
||||
|
@ -69,7 +69,10 @@ def load_model(
|
||||
weight_file = weight_utils.download_weights_if_necessary(
|
||||
file_name="age_model_weights.h5", source_url=url
|
||||
)
|
||||
age_model.load_weights(weight_file)
|
||||
|
||||
age_model = weight_utils.load_model_weights(
|
||||
model=age_model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return age_model
|
||||
|
||||
|
@ -96,6 +96,8 @@ def load_model(
|
||||
file_name="facial_expression_model_weights.h5", source_url=url
|
||||
)
|
||||
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(
|
||||
model=model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return model
|
||||
|
@ -72,6 +72,8 @@ def load_model(
|
||||
file_name="gender_model_weights.h5", source_url=url
|
||||
)
|
||||
|
||||
gender_model.load_weights(weight_file)
|
||||
gender_model = weight_utils.load_model_weights(
|
||||
model=gender_model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return gender_model
|
||||
|
@ -69,6 +69,8 @@ def load_model(
|
||||
file_name="race_model_single_batch.h5", source_url=url
|
||||
)
|
||||
|
||||
race_model.load_weights(weight_file)
|
||||
race_model = weight_utils.load_model_weights(
|
||||
model=race_model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return race_model
|
||||
|
@ -82,7 +82,7 @@ def load_model(
|
||||
file_name="arcface_weights.h5", source_url=url
|
||||
)
|
||||
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
|
||||
# ---------------------------------------
|
||||
|
||||
return model
|
||||
|
@ -89,6 +89,8 @@ def load_model(
|
||||
file_name="deepid_keras_weights.h5", source_url=url
|
||||
)
|
||||
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(
|
||||
model=model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return model
|
||||
|
@ -1668,7 +1668,9 @@ def load_facenet128d_model(
|
||||
weight_file = weight_utils.download_weights_if_necessary(
|
||||
file_name="facenet_weights.h5", source_url=url
|
||||
)
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(
|
||||
model=model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@ -1687,6 +1689,8 @@ def load_facenet512d_model(
|
||||
weight_file = weight_utils.download_weights_if_necessary(
|
||||
file_name="facenet512_weights.h5", source_url=url
|
||||
)
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(
|
||||
model=model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return model
|
||||
|
@ -86,7 +86,7 @@ def load_model(
|
||||
file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip"
|
||||
)
|
||||
|
||||
base_model.load_weights(weight_file)
|
||||
base_model = weight_utils.load_model_weights(model=base_model, weight_file=weight_file)
|
||||
|
||||
# drop F8 and D0. F7 is the representation layer.
|
||||
deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output)
|
||||
|
@ -74,7 +74,9 @@ def load_model():
|
||||
file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS
|
||||
)
|
||||
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(
|
||||
model=model, weight_file=weight_file
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -385,7 +385,9 @@ def load_model(
|
||||
file_name="openface_weights.h5", source_url=url
|
||||
)
|
||||
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(
|
||||
model=model, weight_file=weight_file
|
||||
)
|
||||
|
||||
# -----------------------------------
|
||||
|
||||
|
@ -140,7 +140,9 @@ def load_model(
|
||||
file_name="vgg_face_weights.h5", source_url=url
|
||||
)
|
||||
|
||||
model.load_weights(weight_file)
|
||||
model = weight_utils.load_model_weights(
|
||||
model=model, weight_file=weight_file
|
||||
)
|
||||
|
||||
# 2622d dimensional model
|
||||
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
|
||||
|
Loading…
x
Reference in New Issue
Block a user