mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 19:45:21 +00:00
enforce tf 2.12 or less
This commit is contained in:
parent
47363c6efd
commit
0c8e869371
@ -23,7 +23,6 @@ if tf_major == 1:
|
||||
Flatten,
|
||||
Dense,
|
||||
Dropout,
|
||||
LocallyConnected2D,
|
||||
)
|
||||
else:
|
||||
from tensorflow.keras.models import Model, Sequential
|
||||
@ -33,7 +32,6 @@ else:
|
||||
Flatten,
|
||||
Dense,
|
||||
Dropout,
|
||||
LocallyConnected2D,
|
||||
)
|
||||
|
||||
|
||||
@ -45,6 +43,14 @@ class DeepFaceClient(FacialRecognition):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# DeepFace requires tf 2.12 or less
|
||||
if tf_major == 2 and tf_minor > 12:
|
||||
# Ref: https://github.com/serengil/deepface/pull/1079
|
||||
raise ValueError(
|
||||
"DeepFace model requires LocallyConnected2D but it is no longer supported"
|
||||
f" after tf 2.12 but you have {tf_major}.{tf_minor}. You need to downgrade your tf."
|
||||
)
|
||||
|
||||
self.model = load_model()
|
||||
self.model_name = "DeepFace"
|
||||
self.input_shape = (152, 152)
|
||||
@ -69,6 +75,13 @@ def load_model(
|
||||
"""
|
||||
Construct DeepFace model, download its weights and load
|
||||
"""
|
||||
# we have some checks for this dependency in the init of client
|
||||
# putting this in global causes library initialization
|
||||
if tf_major == 1:
|
||||
from keras.layers import LocallyConnected2D
|
||||
else:
|
||||
from tensorflow.keras.layers import LocallyConnected2D
|
||||
|
||||
base_model = Sequential()
|
||||
base_model.add(
|
||||
Convolution2D(32, (11, 11), activation="relu", name="C1", input_shape=(152, 152, 3))
|
||||
|
Loading…
x
Reference in New Issue
Block a user