diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index e9f6d9d..c8920d8 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -4,6 +4,11 @@ import warnings import logging from typing import Any, Dict, List, Tuple, Union, Optional +# this has to be set before importing tensorflow +os.environ["TF_USE_LEGACY_KERAS"] = "1" + +# pylint: disable=wrong-import-position + # 3rd party dependencies import numpy as np import pandas as pd @@ -28,6 +33,9 @@ logger = Logger(module="DeepFace") # ----------------------------------- # configurations for dependencies +# users should install tf_keras package if they are using tf 2.16 or later versions +package_utils.validate_for_keras3() + warnings.filterwarnings("ignore") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf_version = package_utils.get_tf_major_version() diff --git a/deepface/commons/package_utils.py b/deepface/commons/package_utils.py index 20fa35e..1620732 100644 --- a/deepface/commons/package_utils.py +++ b/deepface/commons/package_utils.py @@ -50,3 +50,24 @@ def find_hash_of_file(file_path: str) -> str: hasher = hashlib.sha1() hasher.update(properties.encode("utf-8")) return hasher.hexdigest() + + +def validate_for_keras3(): + tf_major = get_tf_major_version() + tf_minor = get_tf_minor_version() + + # tf_keras is a must dependency after tf 2.16 + if tf_major == 1 or (tf_major == 2 and tf_minor < 16): + return + + try: + import tf_keras + + logger.debug(f"tf_keras is already available - {tf_keras.__version__}") + except ImportError as err: + # you may consider to install that package here + raise ValueError( + f"You have tensorflow {tf.__version__} and this requires " + "tf-keras package. Please run `pip install tf-keras` " + "or downgrade your tensorflow." + ) from err