mirror of
https://github.com/serengil/deepface.git
synced 2025-06-08 20:45:22 +00:00
30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
from abc import ABC
|
|
from typing import Any, Union, List, Tuple
|
|
import numpy as np
|
|
from deepface.commons import package_utils
|
|
|
|
# Pick the Keras ``Model`` class to match the detected TensorFlow major
# version: anything other than 2 uses the standalone ``keras`` package,
# while version 2 uses the copy shipped as ``tensorflow.keras``.
tf_version = package_utils.get_tf_major_version()
if tf_version != 2:
    from keras.models import Model
else:
    from tensorflow.keras.models import Model
|
|
|
|
# Note: every facial recognition model must inherit from this class.
|
|
|
|
# pylint: disable=too-few-public-methods
|
|
class FacialRecognition(ABC):
    """Common base class for facial recognition model wrappers.

    Subclasses are expected to populate the attributes declared below.
    The default ``forward`` implementation only works when ``model`` is a
    Keras ``Model``; wrappers around any other kind of model must override
    ``forward`` themselves.
    """

    # Underlying network: a Keras Model, or any other object when the
    # subclass supplies its own forward().
    model: Union[Model, Any]
    # Human-readable model name; used in the error raised by forward().
    model_name: str
    # presumably the (height, width) the model expects -- TODO confirm order
    input_shape: Tuple[int, int]
    # presumably the length of the produced embedding -- TODO confirm
    output_shape: int

    def forward(self, img: np.ndarray) -> List[float]:
        """Run the Keras model on ``img`` and return its first output row.

        Args:
            img: Array handed directly to the Keras model. Assumed to carry
                a leading batch axis, since only element ``[0]`` of the
                output is returned -- TODO confirm against callers.

        Returns:
            The first element of the model output, converted to a plain
            Python list.

        Raises:
            ValueError: If ``self.model`` is not a Keras ``Model`` and the
                subclass did not override this method.
        """
        if not isinstance(self.model, Model):
            # The trailing space after the comma matters: the two literals
            # are concatenated and previously ran together as "model,but".
            raise ValueError(
                "You must overwrite forward method if it is not a keras model, "
                f"but {self.model_name} not overwritten!"
            )
        # Call the model directly instead of model.predict(): predict causes
        # memory issues when it is invoked repeatedly in a for loop.
        return self.model(img, training=False).numpy()[0].tolist()
|