mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 12:05:22 +00:00
load weights is done in a common function
This commit is contained in:
parent
a3088ac903
commit
8b86f36390
@ -8,9 +8,15 @@ import bz2
|
|||||||
import gdown
|
import gdown
|
||||||
|
|
||||||
# project dependencies
|
# project dependencies
|
||||||
from deepface.commons import folder_utils
|
from deepface.commons import folder_utils, package_utils
|
||||||
from deepface.commons.logger import Logger
|
from deepface.commons.logger import Logger
|
||||||
|
|
||||||
|
tf_version = package_utils.get_tf_major_version()
|
||||||
|
if tf_version == 1:
|
||||||
|
from keras.models import Sequential
|
||||||
|
else:
|
||||||
|
from tensorflow.keras.models import Sequential
|
||||||
|
|
||||||
logger = Logger()
|
logger = Logger()
|
||||||
|
|
||||||
|
|
||||||
@ -63,3 +69,24 @@ def download_weights_if_necessary(
|
|||||||
logger.info(f"{target_file}.bz2 unzipped")
|
logger.info(f"{target_file}.bz2 unzipped")
|
||||||
|
|
||||||
return target_file
|
return target_file
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
|
||||||
|
"""
|
||||||
|
Load pre-trained weights for a given model
|
||||||
|
Args:
|
||||||
|
model (keras.models.Sequential): pre-built model
|
||||||
|
weight_file (str): exact path of pre-trained weights
|
||||||
|
Returns:
|
||||||
|
model (keras.models.Sequential): pre-built model with
|
||||||
|
updated weights
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model.load_weights(weight_file)
|
||||||
|
except Exception as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Exception while loading pre-trained weights from {weight_file}."
|
||||||
|
"Possible reason is broken file during downloading weights."
|
||||||
|
"You may consider to delete it manually."
|
||||||
|
) from err
|
||||||
|
return model
|
||||||
|
@ -69,7 +69,10 @@ def load_model(
|
|||||||
weight_file = weight_utils.download_weights_if_necessary(
|
weight_file = weight_utils.download_weights_if_necessary(
|
||||||
file_name="age_model_weights.h5", source_url=url
|
file_name="age_model_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
age_model.load_weights(weight_file)
|
|
||||||
|
age_model = weight_utils.load_model_weights(
|
||||||
|
model=age_model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return age_model
|
return age_model
|
||||||
|
|
||||||
|
@ -96,6 +96,8 @@ def load_model(
|
|||||||
file_name="facial_expression_model_weights.h5", source_url=url
|
file_name="facial_expression_model_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(
|
||||||
|
model=model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -72,6 +72,8 @@ def load_model(
|
|||||||
file_name="gender_model_weights.h5", source_url=url
|
file_name="gender_model_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
gender_model.load_weights(weight_file)
|
gender_model = weight_utils.load_model_weights(
|
||||||
|
model=gender_model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return gender_model
|
return gender_model
|
||||||
|
@ -69,6 +69,8 @@ def load_model(
|
|||||||
file_name="race_model_single_batch.h5", source_url=url
|
file_name="race_model_single_batch.h5", source_url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
race_model.load_weights(weight_file)
|
race_model = weight_utils.load_model_weights(
|
||||||
|
model=race_model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return race_model
|
return race_model
|
||||||
|
@ -82,7 +82,7 @@ def load_model(
|
|||||||
file_name="arcface_weights.h5", source_url=url
|
file_name="arcface_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
|
||||||
# ---------------------------------------
|
# ---------------------------------------
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -89,6 +89,8 @@ def load_model(
|
|||||||
file_name="deepid_keras_weights.h5", source_url=url
|
file_name="deepid_keras_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(
|
||||||
|
model=model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -1668,7 +1668,9 @@ def load_facenet128d_model(
|
|||||||
weight_file = weight_utils.download_weights_if_necessary(
|
weight_file = weight_utils.download_weights_if_necessary(
|
||||||
file_name="facenet_weights.h5", source_url=url
|
file_name="facenet_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(
|
||||||
|
model=model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -1687,6 +1689,8 @@ def load_facenet512d_model(
|
|||||||
weight_file = weight_utils.download_weights_if_necessary(
|
weight_file = weight_utils.download_weights_if_necessary(
|
||||||
file_name="facenet512_weights.h5", source_url=url
|
file_name="facenet512_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(
|
||||||
|
model=model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -86,7 +86,7 @@ def load_model(
|
|||||||
file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip"
|
file_name="VGGFace2_DeepFace_weights_val-0.9034.h5", source_url=url, compress_type="zip"
|
||||||
)
|
)
|
||||||
|
|
||||||
base_model.load_weights(weight_file)
|
base_model = weight_utils.load_model_weights(model=base_model, weight_file=weight_file)
|
||||||
|
|
||||||
# drop F8 and D0. F7 is the representation layer.
|
# drop F8 and D0. F7 is the representation layer.
|
||||||
deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output)
|
deepface_model = Model(inputs=base_model.layers[0].input, outputs=base_model.layers[-3].output)
|
||||||
|
@ -74,7 +74,9 @@ def load_model():
|
|||||||
file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS
|
file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(
|
||||||
|
model=model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -385,7 +385,9 @@ def load_model(
|
|||||||
file_name="openface_weights.h5", source_url=url
|
file_name="openface_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(
|
||||||
|
model=model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------
|
# -----------------------------------
|
||||||
|
|
||||||
|
@ -140,7 +140,9 @@ def load_model(
|
|||||||
file_name="vgg_face_weights.h5", source_url=url
|
file_name="vgg_face_weights.h5", source_url=url
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(weight_file)
|
model = weight_utils.load_model_weights(
|
||||||
|
model=model, weight_file=weight_file
|
||||||
|
)
|
||||||
|
|
||||||
# 2622d dimensional model
|
# 2622d dimensional model
|
||||||
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
|
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user