check tf_keras installation for tf 2.16 or later versions

This commit is contained in:
Sefik Ilkin Serengil 2024-03-19 08:45:10 +00:00
parent 2cbfc417a4
commit 71bea587da
2 changed files with 29 additions and 0 deletions

View File

@ -4,6 +4,11 @@ import warnings
import logging import logging
from typing import Any, Dict, List, Tuple, Union, Optional from typing import Any, Dict, List, Tuple, Union, Optional
# this has to be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1"
# pylint: disable=wrong-import-position
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -28,6 +33,9 @@ logger = Logger(module="DeepFace")
# ----------------------------------- # -----------------------------------
# configurations for dependencies # configurations for dependencies
# users should install tf_keras package if they are using tf 2.16 or later versions
package_utils.validate_for_keras3()
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf_version = package_utils.get_tf_major_version() tf_version = package_utils.get_tf_major_version()

View File

@ -50,3 +50,24 @@ def find_hash_of_file(file_path: str) -> str:
hasher = hashlib.sha1() hasher = hashlib.sha1()
hasher.update(properties.encode("utf-8")) hasher.update(properties.encode("utf-8"))
return hasher.hexdigest() return hasher.hexdigest()
def validate_for_keras3():
tf_major = get_tf_major_version()
tf_minor = get_tf_minor_version()
# tf_keras is a must dependency after tf 2.16
if tf_major == 1 or (tf_major == 2 and tf_minor < 16):
return
try:
import tf_keras
logger.debug(f"tf_keras is already available - {tf_keras.__version__}")
except ImportError as err:
# you may consider to install that package here
raise ValueError(
f"You have tensorflow {tf.__version__} and this requires "
"tf-keras package. Please run `pip install tf-keras` "
"or downgrade your tensorflow."
) from err