tf requirement for deepface package

This commit is contained in:
Sefik Ilkin Serengil 2024-03-09 22:34:55 +00:00
parent cdb0fa0b95
commit 42e911958b

View File

@ -12,13 +12,13 @@ logger = Logger(module="basemodels.FbDeepFace")
# -------------------------------- # --------------------------------
# dependency configuration # dependency configuration
tf_version = package_utils.get_tf_major_version() tf_major = package_utils.get_tf_major_version()
tf_minor = package_utils.get_tf_minor_version()
if tf_version == 1: if tf_major == 1:
from keras.models import Model, Sequential from keras.models import Model, Sequential
from keras.layers import ( from keras.layers import (
Convolution2D, Convolution2D,
LocallyConnected2D,
MaxPooling2D, MaxPooling2D,
Flatten, Flatten,
Dense, Dense,
@ -28,7 +28,6 @@ else:
from tensorflow.keras.models import Model, Sequential from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import ( from tensorflow.keras.layers import (
Convolution2D, Convolution2D,
LocallyConnected2D,
MaxPooling2D, MaxPooling2D,
Flatten, Flatten,
Dense, Dense,
@ -44,13 +43,12 @@ class DeepFaceClient(FacialRecognition):
""" """
def __init__(self): def __init__(self):
major = package_utils.get_tf_major_version() # DeepFace requires tf 2.12 or less
minor = package_utils.get_tf_minor_version() if tf_major == 2 and tf_minor > 12:
if major == 2 and minor > 12:
# Ref: https://github.com/serengil/deepface/pull/1079 # Ref: https://github.com/serengil/deepface/pull/1079
raise ValueError( raise ValueError(
"DeepFace model requires LocallyConnected2D but it is no longer supported" "DeepFace model requires LocallyConnected2D but it is no longer supported"
f" after tf 2.12 but you have {major}.{minor}. You need to downgrade your tf." f" after tf 2.12 but you have {tf_major}.{tf_minor}. You need to downgrade your tf."
) )
self.model = load_model() self.model = load_model()
@ -77,6 +75,13 @@ def load_model(
""" """
Construct DeepFace model, download its weights and load Construct DeepFace model, download its weights and load
""" """
# we have some checks for this dependency in the init of client
# putting this in global causes library initialization
if tf_major == 1:
from keras.layers import LocallyConnected2D
else:
from tensorflow.keras.layers import LocallyConnected2D
base_model = Sequential() base_model = Sequential()
base_model.add( base_model.add(
Convolution2D(32, (11, 11), activation="relu", name="C1", input_shape=(152, 152, 3)) Convolution2D(32, (11, 11), activation="relu", name="C1", input_shape=(152, 152, 3))