Refine type hints as per PR review

This commit is contained in:
RazaProdigy 2023-12-18 13:18:57 +04:00
parent cfff58792a
commit 90a0282ed9

View File

@ -5,7 +5,9 @@ import warnings
import time import time
import pickle import pickle
import logging import logging
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple, Union
from keras.engine.functional import Functional
from deepface.basemodels import DlibResNet, SFace
# 3rd party dependencies # 3rd party dependencies
import numpy as np import numpy as np
@ -46,7 +48,7 @@ if tf_version == 2:
# ----------------------------------- # -----------------------------------
def build_model(model_name: str) -> Any: def build_model(model_name: str) -> Union[Functional, DlibResNet.DlibResNet, SFace.SFaceModel]:
""" """
This function builds a deepface model This function builds a deepface model
Parameters: Parameters:
@ -92,8 +94,8 @@ def build_model(model_name: str) -> Any:
def verify( def verify(
img1_path: str, img1_path: Union[str, np.ndarray],
img2_path: str, img2_path: Union[str, np.ndarray],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
detector_backend: str = "opencv", detector_backend: str = "opencv",
distance_metric: str = "cosine", distance_metric: str = "cosine",
@ -232,7 +234,7 @@ def verify(
def analyze( def analyze(
img_path: str, img_path: Union[str, np.ndarray],,
actions: Tuple[str, ...] = ("emotion", "age", "gender", "race"), actions: Tuple[str, ...] = ("emotion", "age", "gender", "race"),
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
@ -410,7 +412,7 @@ def analyze(
def find( def find(
img_path: str, img_path: Union[str, np.ndarray],
db_path : str, db_path : str,
model_name : str ="VGG-Face", model_name : str ="VGG-Face",
distance_metric : str ="cosine", distance_metric : str ="cosine",
@ -651,7 +653,7 @@ def find(
def represent( def represent(
img_path: str, img_path: Union[str, np.ndarray],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
@ -817,7 +819,7 @@ def stream(
def extract_faces( def extract_faces(
img_path: str, img_path: Union[str, np.ndarray],
target_size: Tuple[int, int] = (224, 224), target_size: Tuple[int, int] = (224, 224),
detector_backend: str = "opencv", detector_backend: str = "opencv",
enforce_detection: bool = True, enforce_detection: bool = True,