enforce tf 2.12 or less

This commit is contained in:
Sefik Ilkin Serengil 2024-03-22 10:02:52 +00:00
parent 47363c6efd
commit 0c8e869371

View File

@ -23,7 +23,6 @@ if tf_major == 1:
Flatten, Flatten,
Dense, Dense,
Dropout, Dropout,
LocallyConnected2D,
) )
else: else:
from tensorflow.keras.models import Model, Sequential from tensorflow.keras.models import Model, Sequential
@ -33,7 +32,6 @@ else:
Flatten, Flatten,
Dense, Dense,
Dropout, Dropout,
LocallyConnected2D,
) )
@ -45,6 +43,14 @@ class DeepFaceClient(FacialRecognition):
""" """
def __init__(self): def __init__(self):
# DeepFace requires tf 2.12 or less
if tf_major == 2 and tf_minor > 12:
# Ref: https://github.com/serengil/deepface/pull/1079
raise ValueError(
"DeepFace model requires LocallyConnected2D but it is no longer supported"
f" after tf 2.12 but you have {tf_major}.{tf_minor}. You need to downgrade your tf."
)
self.model = load_model() self.model = load_model()
self.model_name = "DeepFace" self.model_name = "DeepFace"
self.input_shape = (152, 152) self.input_shape = (152, 152)
@ -69,6 +75,13 @@ def load_model(
""" """
Construct DeepFace model, download its weights and load Construct DeepFace model, download its weights and load
""" """
# we have some checks for this dependency in the init of client
# putting this in global causes library initialization
if tf_major == 1:
from keras.layers import LocallyConnected2D
else:
from tensorflow.keras.layers import LocallyConnected2D
base_model = Sequential() base_model = Sequential()
base_model.add( base_model.add(
Convolution2D(32, (11, 11), activation="relu", name="C1", input_shape=(152, 152, 3)) Convolution2D(32, (11, 11), activation="relu", name="C1", input_shape=(152, 152, 3))