# Source: deepface/deepface/commons/weight_utils.py
# Snapshot: 2024-09-01 09:37:29 +01:00 (93 lines, 3.1 KiB, Python)
# built-in dependencies
import bz2
import os
import shutil
import zipfile
from typing import Optional

# 3rd party dependencies
import gdown

# project dependencies
from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger
# Resolve where `Sequential` lives for the installed TensorFlow major version:
# TF 1.x setups import it from standalone keras, every other version from
# tensorflow.keras. It is only used in type annotations below.
tf_version = package_utils.get_tf_major_version()
if tf_version != 1:
    from tensorflow.keras.models import Sequential
else:
    from keras.models import Sequential

# module-wide logger shared by the helpers in this file
logger = Logger()
def download_weights_if_necessary(
    file_name: str, source_url: str, compress_type: Optional[str] = None
) -> str:
    """
    Download the weights of a pre-trained model from external source if not downloaded yet.

    Args:
        file_name (str): target file name with extension
        source_url (url): source url to be downloaded
        compress_type (optional str): compression of the remote file, either "zip" or
            "bz2". When given, the archive is saved as "<target_file>.<compress_type>"
            and then extracted into the weights folder. None downloads the file as is.
    Returns
        target_file (str): exact path for the target file
    Raises:
        ValueError: if compress_type is not supported, or if the download fails
    """
    home = folder_utils.get_deepface_home()
    weights_dir = os.path.join(home, ".deepface/weights")
    target_file = os.path.join(weights_dir, file_name)

    if os.path.isfile(target_file):
        logger.debug(f"{file_name} is already available at {target_file}")
        return target_file

    # Fail fast: any other value would be downloaded but never extracted,
    # leaving target_file missing for the caller.
    if compress_type not in (None, "zip", "bz2"):
        raise ValueError(
            f"Unsupported compress_type '{compress_type}' for {file_name}. "
            "Expected None, 'zip' or 'bz2'."
        )

    download_path = target_file if compress_type is None else f"{target_file}.{compress_type}"

    try:
        logger.info(f"🔗 {file_name} will be downloaded from {source_url} to {target_file}...")
        # the weights folder may not exist yet (e.g. on a fresh installation)
        os.makedirs(weights_dir, exist_ok=True)
        gdown.download(source_url, download_path, quiet=False)
        # do not rely solely on gdown raising: confirm the file actually landed
        if not os.path.isfile(download_path):
            raise FileNotFoundError(f"{download_path} was not created by the download")
    except Exception as err:
        raise ValueError(
            f"⛓️‍💥 An exception occurred while downloading {file_name} from {source_url}. "
            f"Consider downloading it manually to {target_file}."
        ) from err

    # uncompress downloaded file
    if compress_type == "zip":
        with zipfile.ZipFile(download_path, "r") as zip_ref:
            zip_ref.extractall(weights_dir)
        logger.info(f"{target_file}.zip unzipped")
    elif compress_type == "bz2":
        # Decompress into a temporary file and rename it only once it is complete,
        # so an interrupted run cannot leave a truncated target_file behind that the
        # cache check above would later accept. Streaming also avoids holding the
        # whole decompressed payload in memory.
        partial_file = f"{target_file}.part"
        with bz2.open(download_path, "rb") as src, open(partial_file, "wb") as dst:
            shutil.copyfileobj(src, dst)
        os.replace(partial_file, target_file)
        logger.info(f"{target_file}.bz2 unzipped")

    return target_file
def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
    """
    Load pre-trained weights for a given model

    Args:
        model (keras.models.Sequential): pre-built model
        weight_file (str): exact path of pre-trained weights
    Returns:
        model (keras.models.Sequential): pre-built model with
            updated weights
    Raises:
        ValueError: if model.load_weights fails for weight_file; the original
            exception is chained as the cause
    """
    try:
        model.load_weights(weight_file)
    except Exception as err:
        # Adjacent string literals are concatenated verbatim, so every sentence
        # must carry its own trailing space to keep the message readable.
        raise ValueError(
            f"An exception occurred while loading the pre-trained weights from {weight_file}. "
            "This might have happened due to an interruption during the download. "
            "You may want to delete it and allow DeepFace to download it again "
            "during the next run. "
            "If the issue persists, consider downloading the file directly from the source "
            "and copying it to the target folder."
        ) from err
    return model