Merge branch 'serengil:master' into master

This commit is contained in:
Tyas 2024-10-10 16:50:58 +07:00 committed by GitHub
commit 49c041945a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 361 additions and 211 deletions

View File

@ -323,18 +323,18 @@ def find(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False). anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns: Returns:
results (List[pd.DataFrame] or List[List[Dict[str, Any]]]): results (List[pd.DataFrame] or List[List[Dict[str, Any]]]):
A list of pandas dataframes (if `batched=False`) or A list of pandas dataframes (if `batched=False`) or
a list of dicts (if `batched=True`). a list of dicts (if `batched=True`).
Each dataframe or dict corresponds to the identity information for Each dataframe or dict corresponds to the identity information for
an individual detected in the source image. an individual detected in the source image.
Note: If you have a large database and/or a source photo with many faces, Note: If you have a large database and/or a source photo with many faces,
use `batched=True`, as it is optimized for large batch processing. use `batched=True`, as it is optimized for large batch processing.
Please pay attention that when using `batched=True`, the function returns Please pay attention that when using `batched=True`, the function returns
a list of dicts (not a list of DataFrames), a list of dicts (not a list of DataFrames),
but with the same keys as the columns in the DataFrame. but with the same keys as the columns in the DataFrame.
The DataFrame columns or dict keys include: The DataFrame columns or dict keys include:
- 'identity': Identity label of the detected individual. - 'identity': Identity label of the detected individual.
@ -364,7 +364,7 @@ def find(
silent=silent, silent=silent,
refresh_database=refresh_database, refresh_database=refresh_database,
anti_spoofing=anti_spoofing, anti_spoofing=anti_spoofing,
batched=batched batched=batched,
) )

View File

@ -11,6 +11,7 @@ import gdown
from deepface.commons import folder_utils, package_utils from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
tf_version = package_utils.get_tf_major_version() tf_version = package_utils.get_tf_major_version()
if tf_version == 1: if tf_version == 1:
from keras.models import Sequential from keras.models import Sequential
@ -19,6 +20,8 @@ else:
logger = Logger() logger = Logger()
# pylint: disable=line-too-long, use-maxsplit-arg
ALLOWED_COMPRESS_TYPES = ["zip", "bz2"] ALLOWED_COMPRESS_TYPES = ["zip", "bz2"]
@ -95,3 +98,98 @@ def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
"and copying it to the target folder." "and copying it to the target folder."
) from err ) from err
return model return model
def download_all_models_in_one_shot() -> None:
"""
Download all model weights in one shot
"""
# import model weights from module here to avoid circular import issue
from deepface.models.facial_recognition.VGGFace import WEIGHTS_URL as VGGFACE_WEIGHTS
from deepface.models.facial_recognition.Facenet import FACENET128_WEIGHTS, FACENET512_WEIGHTS
from deepface.models.facial_recognition.OpenFace import WEIGHTS_URL as OPENFACE_WEIGHTS
from deepface.models.facial_recognition.FbDeepFace import WEIGHTS_URL as FBDEEPFACE_WEIGHTS
from deepface.models.facial_recognition.ArcFace import WEIGHTS_URL as ARCFACE_WEIGHTS
from deepface.models.facial_recognition.DeepID import WEIGHTS_URL as DEEPID_WEIGHTS
from deepface.models.facial_recognition.SFace import WEIGHTS_URL as SFACE_WEIGHTS
from deepface.models.facial_recognition.GhostFaceNet import WEIGHTS_URL as GHOSTFACENET_WEIGHTS
from deepface.models.facial_recognition.Dlib import WEIGHT_URL as DLIB_FR_WEIGHTS
from deepface.models.demography.Age import WEIGHTS_URL as AGE_WEIGHTS
from deepface.models.demography.Gender import WEIGHTS_URL as GENDER_WEIGHTS
from deepface.models.demography.Race import WEIGHTS_URL as RACE_WEIGHTS
from deepface.models.demography.Emotion import WEIGHTS_URL as EMOTION_WEIGHTS
from deepface.models.spoofing.FasNet import (
FIRST_WEIGHTS_URL as FASNET_1ST_WEIGHTS,
SECOND_WEIGHTS_URL as FASNET_2ND_WEIGHTS,
)
from deepface.models.face_detection.Ssd import (
MODEL_URL as SSD_MODEL,
WEIGHTS_URL as SSD_WEIGHTS,
)
from deepface.models.face_detection.Yolo import (
WEIGHT_URL as YOLOV8_WEIGHTS,
WEIGHT_NAME as YOLOV8_WEIGHT_NAME,
)
from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS
from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS
from deepface.models.face_detection.CenterFace import WEIGHTS_URL as CENTERFACE_WEIGHTS
WEIGHTS = [
# facial recognition
VGGFACE_WEIGHTS,
FACENET128_WEIGHTS,
FACENET512_WEIGHTS,
OPENFACE_WEIGHTS,
FBDEEPFACE_WEIGHTS,
ARCFACE_WEIGHTS,
DEEPID_WEIGHTS,
SFACE_WEIGHTS,
{
"filename": "ghostfacenet_v1.h5",
"url": GHOSTFACENET_WEIGHTS,
},
DLIB_FR_WEIGHTS,
# demography
AGE_WEIGHTS,
GENDER_WEIGHTS,
RACE_WEIGHTS,
EMOTION_WEIGHTS,
# spoofing
FASNET_1ST_WEIGHTS,
FASNET_2ND_WEIGHTS,
# face detection
SSD_MODEL,
SSD_WEIGHTS,
{
"filename": YOLOV8_WEIGHT_NAME,
"url": YOLOV8_WEIGHTS,
},
YUNET_WEIGHTS,
DLIB_FD_WEIGHTS,
CENTERFACE_WEIGHTS,
]
for i in WEIGHTS:
if isinstance(i, str):
url = i
filename = i.split("/")[-1]
compress_type = None
# if compressed file will be downloaded, get rid of its extension
if filename.endswith(tuple(ALLOWED_COMPRESS_TYPES)):
for ext in ALLOWED_COMPRESS_TYPES:
compress_type = ext
if filename.endswith(f".{ext}"):
filename = filename[: -(len(ext) + 1)]
break
elif isinstance(i, dict):
filename = i["filename"]
url = i["url"]
else:
raise ValueError("unimplemented scenario")
logger.info(
f"Downloading {url} to ~/.deepface/weights/{filename} with {compress_type} compression"
)
download_weights_if_necessary(
file_name=filename, source_url=url, compress_type=compress_type
)

View File

@ -6,7 +6,7 @@ import numpy as np
# Notice that all facial detector models must be inherited from this class # Notice that all facial detector models must be inherited from this class
# pylint: disable=unnecessary-pass, too-few-public-methods # pylint: disable=unnecessary-pass, too-few-public-methods, too-many-instance-attributes
class Detector(ABC): class Detector(ABC):
@abstractmethod @abstractmethod
def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]: def detect_faces(self, img: np.ndarray) -> List["FacialAreaRegion"]:
@ -45,6 +45,7 @@ class FacialAreaRegion:
confidence (float, optional): Confidence score associated with the face detection. confidence (float, optional): Confidence score associated with the face detection.
Default is None. Default is None.
""" """
x: int x: int
y: int y: int
w: int w: int
@ -52,6 +53,9 @@ class FacialAreaRegion:
left_eye: Optional[Tuple[int, int]] = None left_eye: Optional[Tuple[int, int]] = None
right_eye: Optional[Tuple[int, int]] = None right_eye: Optional[Tuple[int, int]] = None
confidence: Optional[float] = None confidence: Optional[float] = None
nose: Optional[Tuple[int, int]] = None
mouth_right: Optional[Tuple[int, int]] = None
mouth_left: Optional[Tuple[int, int]] = None
@dataclass @dataclass
@ -63,7 +67,8 @@ class DetectedFace:
img (np.ndarray): detected face image as numpy array img (np.ndarray): detected face image as numpy array
facial_area (FacialAreaRegion): detected face's metadata (e.g. bounding box) facial_area (FacialAreaRegion): detected face's metadata (e.g. bounding box)
confidence (float): confidence score for face detection confidence (float): confidence score for face detection
""" """
img: np.ndarray img: np.ndarray
facial_area: FacialAreaRegion facial_area: FacialAreaRegion
confidence: float confidence: float

View File

@ -23,6 +23,10 @@ else:
# ---------------------------------------- # ----------------------------------------
WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
)
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class ApparentAgeClient(Demography): class ApparentAgeClient(Demography):
""" """
@ -41,7 +45,7 @@ class ApparentAgeClient(Demography):
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Construct age model, download its weights and load Construct age model, download its weights and load
@ -70,12 +74,11 @@ def load_model(
file_name="age_model_weights.h5", source_url=url file_name="age_model_weights.h5", source_url=url
) )
age_model = weight_utils.load_model_weights( age_model = weight_utils.load_model_weights(model=age_model, weight_file=weight_file)
model=age_model, weight_file=weight_file
)
return age_model return age_model
def find_apparent_age(age_predictions: np.ndarray) -> np.float64: def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
""" """
Find apparent age prediction from a given probas of ages Find apparent age prediction from a given probas of ages

View File

@ -7,11 +7,6 @@ from deepface.commons import package_utils, weight_utils
from deepface.models.Demography import Demography from deepface.models.Demography import Demography
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
logger = Logger()
# -------------------------------------------
# pylint: disable=line-too-long
# -------------------------------------------
# dependency configuration # dependency configuration
tf_version = package_utils.get_tf_major_version() tf_version = package_utils.get_tf_major_version()
@ -28,12 +23,17 @@ else:
Dense, Dense,
Dropout, Dropout,
) )
# -------------------------------------------
# Labels for the emotions that can be detected by the model. # Labels for the emotions that can be detected by the model.
labels = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"] labels = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]
# pylint: disable=too-few-public-methods logger = Logger()
# pylint: disable=line-too-long, disable=too-few-public-methods
WEIGHTS_URL = "https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5"
class EmotionClient(Demography): class EmotionClient(Demography):
""" """
Emotion model class Emotion model class
@ -56,7 +56,7 @@ class EmotionClient(Demography):
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5", url=WEIGHTS_URL,
) -> Sequential: ) -> Sequential:
""" """
Consruct emotion model, download and load weights Consruct emotion model, download and load weights
@ -96,8 +96,6 @@ def load_model(
file_name="facial_expression_model_weights.h5", source_url=url file_name="facial_expression_model_weights.h5", source_url=url
) )
model = weight_utils.load_model_weights( model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
model=model, weight_file=weight_file
)
return model return model

View File

@ -21,7 +21,8 @@ if tf_version == 1:
else: else:
from tensorflow.keras.models import Model, Sequential from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation from tensorflow.keras.layers import Convolution2D, Flatten, Activation
# -------------------------------------
WEIGHTS_URL="https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5"
# Labels for the genders that can be detected by the model. # Labels for the genders that can be detected by the model.
labels = ["Woman", "Man"] labels = ["Woman", "Man"]
@ -43,7 +44,7 @@ class GenderClient(Demography):
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Construct gender model, download its weights and load Construct gender model, download its weights and load

View File

@ -7,11 +7,8 @@ from deepface.commons import package_utils, weight_utils
from deepface.models.Demography import Demography from deepface.models.Demography import Demography
from deepface.commons.logger import Logger from deepface.commons.logger import Logger
logger = Logger()
# --------------------------
# pylint: disable=line-too-long # pylint: disable=line-too-long
# --------------------------
# dependency configurations # dependency configurations
tf_version = package_utils.get_tf_major_version() tf_version = package_utils.get_tf_major_version()
@ -21,10 +18,15 @@ if tf_version == 1:
else: else:
from tensorflow.keras.models import Model, Sequential from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation from tensorflow.keras.layers import Convolution2D, Flatten, Activation
# --------------------------
WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5"
)
# Labels for the ethnic phenotypes that can be detected by the model. # Labels for the ethnic phenotypes that can be detected by the model.
labels = ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"] labels = ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"]
logger = Logger()
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class RaceClient(Demography): class RaceClient(Demography):
""" """
@ -42,7 +44,7 @@ class RaceClient(Demography):
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Construct race model, download its weights and load Construct race model, download its weights and load
@ -69,8 +71,6 @@ def load_model(
file_name="race_model_single_batch.h5", source_url=url file_name="race_model_single_batch.h5", source_url=url
) )
race_model = weight_utils.load_model_weights( race_model = weight_utils.load_model_weights(model=race_model, weight_file=weight_file)
model=race_model, weight_file=weight_file
)
return race_model return race_model

View File

@ -11,6 +11,7 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
WEIGHTS_URL="http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2"
class DlibClient(Detector): class DlibClient(Detector):
def __init__(self): def __init__(self):
@ -34,7 +35,7 @@ class DlibClient(Detector):
# check required file exists in the home/.deepface/weights folder # check required file exists in the home/.deepface/weights folder
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="shape_predictor_5_face_landmarks.dat", file_name="shape_predictor_5_face_landmarks.dat",
source_url="http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2", source_url=WEIGHTS_URL,
compress_type="bz2", compress_type="bz2",
) )

View File

@ -42,10 +42,19 @@ class RetinaFaceClient(Detector):
# retinaface sets left and right eyes with respect to the person # retinaface sets left and right eyes with respect to the person
left_eye = identity["landmarks"]["left_eye"] left_eye = identity["landmarks"]["left_eye"]
right_eye = identity["landmarks"]["right_eye"] right_eye = identity["landmarks"]["right_eye"]
nose = identity["landmarks"].get("nose")
mouth_right = identity["landmarks"].get("mouth_right")
mouth_left = identity["landmarks"].get("mouth_left")
# eyes are list of float, need to cast them tuple of int # eyes are list of float, need to cast them tuple of int
left_eye = tuple(int(i) for i in left_eye) left_eye = tuple(int(i) for i in left_eye)
right_eye = tuple(int(i) for i in right_eye) right_eye = tuple(int(i) for i in right_eye)
if nose is not None:
nose = tuple(int(i) for i in nose)
if mouth_right is not None:
mouth_right = tuple(int(i) for i in mouth_right)
if mouth_left is not None:
mouth_left = tuple(int(i) for i in mouth_left)
confidence = identity["score"] confidence = identity["score"]
@ -57,6 +66,9 @@ class RetinaFaceClient(Detector):
left_eye=left_eye, left_eye=left_eye,
right_eye=right_eye, right_eye=right_eye,
confidence=confidence, confidence=confidence,
nose=nose,
mouth_left=mouth_left,
mouth_right=mouth_right,
) )
resp.append(facial_area) resp.append(facial_area)

View File

@ -16,6 +16,9 @@ logger = Logger()
# pylint: disable=line-too-long, c-extension-no-member # pylint: disable=line-too-long, c-extension-no-member
MODEL_URL = "https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt"
WEIGHTS_URL = "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel"
class SsdClient(Detector): class SsdClient(Detector):
def __init__(self): def __init__(self):
@ -31,13 +34,13 @@ class SsdClient(Detector):
# model structure # model structure
output_model = weight_utils.download_weights_if_necessary( output_model = weight_utils.download_weights_if_necessary(
file_name="deploy.prototxt", file_name="deploy.prototxt",
source_url="https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt", source_url=MODEL_URL,
) )
# pre-trained weights # pre-trained weights
output_weights = weight_utils.download_weights_if_necessary( output_weights = weight_utils.download_weights_if_necessary(
file_name="res10_300x300_ssd_iter_140000.caffemodel", file_name="res10_300x300_ssd_iter_140000.caffemodel",
source_url="https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel", source_url=WEIGHTS_URL,
) )
try: try:

View File

@ -12,7 +12,7 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
# Model's weights paths # Model's weights paths
PATH = ".deepface/weights/yolov8n-face.pt" WEIGHT_NAME = "yolov8n-face.pt"
# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB # Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb" WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
@ -39,7 +39,7 @@ class YoloClient(Detector):
) from e ) from e
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="yolov8n-face.pt", source_url=WEIGHT_URL file_name=WEIGHT_NAME, source_url=WEIGHT_URL
) )
# Return face_detector # Return face_detector

View File

@ -13,6 +13,9 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
# pylint:disable=line-too-long
WEIGHTS_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
class YuNetClient(Detector): class YuNetClient(Detector):
def __init__(self): def __init__(self):
@ -41,7 +44,7 @@ class YuNetClient(Detector):
# pylint: disable=C0301 # pylint: disable=C0301
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="face_detection_yunet_2023mar.onnx", file_name="face_detection_yunet_2023mar.onnx",
source_url="https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx", source_url=WEIGHTS_URL,
) )
try: try:

View File

@ -42,6 +42,8 @@ else:
Dense, Dense,
) )
WEIGHTS_URL="https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5"
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class ArcFaceClient(FacialRecognition): class ArcFaceClient(FacialRecognition):
""" """
@ -56,7 +58,7 @@ class ArcFaceClient(FacialRecognition):
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Construct ArcFace model, download its weights and load Construct ArcFace model, download its weights and load

View File

@ -34,8 +34,7 @@ else:
# pylint: disable=line-too-long # pylint: disable=line-too-long
WEIGHTS_URL="https://github.com/serengil/deepface_models/releases/download/v1.0/deepid_keras_weights.h5"
# -------------------------------------
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class DeepIdClient(FacialRecognition): class DeepIdClient(FacialRecognition):
@ -51,7 +50,7 @@ class DeepIdClient(FacialRecognition):
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/deepid_keras_weights.h5", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Construct DeepId model, download its weights and load Construct DeepId model, download its weights and load

View File

@ -12,6 +12,7 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
WEIGHT_URL = "http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2"
class DlibClient(FacialRecognition): class DlibClient(FacialRecognition):
@ -70,7 +71,7 @@ class DlibResNet:
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="dlib_face_recognition_resnet_model_v1.dat", file_name="dlib_face_recognition_resnet_model_v1.dat",
source_url="http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2", source_url=WEIGHT_URL,
compress_type="bz2", compress_type="bz2",
) )

View File

@ -39,6 +39,14 @@ else:
from tensorflow.keras.layers import add from tensorflow.keras.layers import add
from tensorflow.keras import backend as K from tensorflow.keras import backend as K
# pylint:disable=line-too-long
FACENET128_WEIGHTS = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/facenet_weights.h5"
)
FACENET512_WEIGHTS = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5"
)
# -------------------------------- # --------------------------------
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
@ -1654,7 +1662,7 @@ def InceptionResNetV1(dimension: int = 128) -> Model:
def load_facenet128d_model( def load_facenet128d_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/facenet_weights.h5", url=FACENET128_WEIGHTS,
) -> Model: ) -> Model:
""" """
Construct FaceNet-128d model, download weights and then load weights Construct FaceNet-128d model, download weights and then load weights
@ -1668,15 +1676,13 @@ def load_facenet128d_model(
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet_weights.h5", source_url=url file_name="facenet_weights.h5", source_url=url
) )
model = weight_utils.load_model_weights( model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
model=model, weight_file=weight_file
)
return model return model
def load_facenet512d_model( def load_facenet512d_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5", url=FACENET512_WEIGHTS,
) -> Model: ) -> Model:
""" """
Construct FaceNet-512d model, download its weights and load Construct FaceNet-512d model, download its weights and load
@ -1689,8 +1695,6 @@ def load_facenet512d_model(
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="facenet512_weights.h5", source_url=url file_name="facenet512_weights.h5", source_url=url
) )
model = weight_utils.load_model_weights( model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
model=model, weight_file=weight_file
)
return model return model

View File

@ -30,9 +30,9 @@ else:
Dropout, Dropout,
) )
# -------------------------------------
# pylint: disable=line-too-long, too-few-public-methods # pylint: disable=line-too-long, too-few-public-methods
WEIGHTS_URL="https://github.com/swghosh/DeepFace/releases/download/weights-vggface2-2d-aligned/VGGFace2_DeepFace_weights_val-0.9034.h5.zip"
class DeepFaceClient(FacialRecognition): class DeepFaceClient(FacialRecognition):
""" """
Fb's DeepFace model class Fb's DeepFace model class
@ -54,7 +54,7 @@ class DeepFaceClient(FacialRecognition):
def load_model( def load_model(
url="https://github.com/swghosh/DeepFace/releases/download/weights-vggface2-2d-aligned/VGGFace2_DeepFace_weights_val-0.9034.h5.zip", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Construct DeepFace model, download its weights and load Construct DeepFace model, download its weights and load

View File

@ -48,7 +48,7 @@ else:
# pylint: disable=line-too-long, too-few-public-methods, no-else-return, unsubscriptable-object, comparison-with-callable # pylint: disable=line-too-long, too-few-public-methods, no-else-return, unsubscriptable-object, comparison-with-callable
PRETRAINED_WEIGHTS = "https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5" WEIGHTS_URL = "https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5"
class GhostFaceNetClient(FacialRecognition): class GhostFaceNetClient(FacialRecognition):
@ -71,12 +71,10 @@ def load_model():
model = GhostFaceNetV1() model = GhostFaceNetV1()
weight_file = weight_utils.download_weights_if_necessary( weight_file = weight_utils.download_weights_if_necessary(
file_name="ghostfacenet_v1.h5", source_url=PRETRAINED_WEIGHTS file_name="ghostfacenet_v1.h5", source_url=WEIGHTS_URL
) )
model = weight_utils.load_model_weights( model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
model=model, weight_file=weight_file
)
return model return model

View File

@ -24,6 +24,8 @@ else:
# pylint: disable=unnecessary-lambda # pylint: disable=unnecessary-lambda
WEIGHTS_URL="https://github.com/serengil/deepface_models/releases/download/v1.0/openface_weights.h5"
# --------------------------------------- # ---------------------------------------
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
@ -40,7 +42,7 @@ class OpenFaceClient(FacialRecognition):
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/openface_weights.h5", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Consturct OpenFace model, download its weights and load Consturct OpenFace model, download its weights and load

View File

@ -13,6 +13,7 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
# pylint: disable=line-too-long, too-few-public-methods # pylint: disable=line-too-long, too-few-public-methods
WEIGHTS_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx"
class SFaceClient(FacialRecognition): class SFaceClient(FacialRecognition):
@ -47,7 +48,7 @@ class SFaceClient(FacialRecognition):
def load_model( def load_model(
url="https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx", url=WEIGHTS_URL,
) -> Any: ) -> Any:
""" """
Construct SFace model, download its weights and load Construct SFace model, download its weights and load

View File

@ -38,6 +38,10 @@ else:
# --------------------------------------- # ---------------------------------------
WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/vgg_face_weights.h5"
)
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class VggFaceClient(FacialRecognition): class VggFaceClient(FacialRecognition):
""" """
@ -126,7 +130,7 @@ def base_model() -> Sequential:
def load_model( def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/vgg_face_weights.h5", url=WEIGHTS_URL,
) -> Model: ) -> Model:
""" """
Final VGG-Face model being used for finding embeddings Final VGG-Face model being used for finding embeddings
@ -140,9 +144,7 @@ def load_model(
file_name="vgg_face_weights.h5", source_url=url file_name="vgg_face_weights.h5", source_url=url
) )
model = weight_utils.load_model_weights( model = weight_utils.load_model_weights(model=model, weight_file=weight_file)
model=model, weight_file=weight_file
)
# 2622d dimensional model # 2622d dimensional model
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output) # vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
@ -151,7 +153,6 @@ def load_model(
# - softmax causes underfitting # - softmax causes underfitting
# - added normalization layer to avoid underfitting with euclidean # - added normalization layer to avoid underfitting with euclidean
# as described here: https://github.com/serengil/deepface/issues/944 # as described here: https://github.com/serengil/deepface/issues/944
base_model_output = Sequential()
base_model_output = Flatten()(model.layers[-5].output) base_model_output = Flatten()(model.layers[-5].output)
# keras backend's l2 normalization layer troubles some gpu users (e.g. issue 957, 966) # keras backend's l2 normalization layer troubles some gpu users (e.g. issue 957, 966)
# base_model_output = Lambda(lambda x: K.l2_normalize(x, axis=1), name="norm_layer")( # base_model_output = Lambda(lambda x: K.l2_normalize(x, axis=1), name="norm_layer")(

View File

@ -12,6 +12,9 @@ from deepface.commons.logger import Logger
logger = Logger() logger = Logger()
# pylint: disable=line-too-long, too-few-public-methods, nested-min-max # pylint: disable=line-too-long, too-few-public-methods, nested-min-max
FIRST_WEIGHTS_URL="https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/2.7_80x80_MiniFASNetV2.pth"
SECOND_WEIGHTS_URL="https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/4_0_0_80x80_MiniFASNetV1SE.pth"
class Fasnet: class Fasnet:
""" """
Mini Face Anti Spoofing Net Library from repo: github.com/minivision-ai/Silent-Face-Anti-Spoofing Mini Face Anti Spoofing Net Library from repo: github.com/minivision-ai/Silent-Face-Anti-Spoofing
@ -35,12 +38,12 @@ class Fasnet:
# download pre-trained models if not installed yet # download pre-trained models if not installed yet
first_model_weight_file = weight_utils.download_weights_if_necessary( first_model_weight_file = weight_utils.download_weights_if_necessary(
file_name="2.7_80x80_MiniFASNetV2.pth", file_name="2.7_80x80_MiniFASNetV2.pth",
source_url="https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/2.7_80x80_MiniFASNetV2.pth", source_url=FIRST_WEIGHTS_URL,
) )
second_model_weight_file = weight_utils.download_weights_if_necessary( second_model_weight_file = weight_utils.download_weights_if_necessary(
file_name="4_0_0_80x80_MiniFASNetV1SE.pth", file_name="4_0_0_80x80_MiniFASNetV1SE.pth",
source_url="https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/4_0_0_80x80_MiniFASNetV1SE.pth", source_url=SECOND_WEIGHTS_URL,
) )
# guarantees Fasnet imported and torch installed # guarantees Fasnet imported and torch installed

View File

@ -148,16 +148,26 @@ def extract_faces(
w = min(width - x - 1, int(current_region.w)) w = min(width - x - 1, int(current_region.w))
h = min(height - y - 1, int(current_region.h)) h = min(height - y - 1, int(current_region.h))
facial_area = {
"x": x,
"y": y,
"w": w,
"h": h,
"left_eye": current_region.left_eye,
"right_eye": current_region.right_eye,
}
# optional nose, mouth_left and mouth_right fields are coming just for retinaface
if current_region.nose is not None:
facial_area["nose"] = current_region.nose
if current_region.mouth_left is not None:
facial_area["mouth_left"] = current_region.mouth_left
if current_region.mouth_right is not None:
facial_area["mouth_right"] = current_region.mouth_right
resp_obj = { resp_obj = {
"face": current_img, "face": current_img,
"facial_area": { "facial_area": facial_area,
"x": x,
"y": y,
"w": w,
"h": h,
"left_eye": current_region.left_eye,
"right_eye": current_region.right_eye,
},
"confidence": round(float(current_region.confidence or 0), 2), "confidence": round(float(current_region.confidence or 0), 2),
} }
@ -272,6 +282,9 @@ def expand_and_align_face(
left_eye = facial_area.left_eye left_eye = facial_area.left_eye
right_eye = facial_area.right_eye right_eye = facial_area.right_eye
confidence = facial_area.confidence confidence = facial_area.confidence
nose = facial_area.nose
mouth_left = facial_area.mouth_left
mouth_right = facial_area.mouth_right
if expand_percentage > 0: if expand_percentage > 0:
# Expand the facial region height and width by the provided percentage # Expand the facial region height and width by the provided percentage
@ -305,11 +318,26 @@ def expand_and_align_face(
left_eye = (left_eye[0] - width_border, left_eye[1] - height_border) left_eye = (left_eye[0] - width_border, left_eye[1] - height_border)
if right_eye is not None: if right_eye is not None:
right_eye = (right_eye[0] - width_border, right_eye[1] - height_border) right_eye = (right_eye[0] - width_border, right_eye[1] - height_border)
if nose is not None:
nose = (nose[0] - width_border, nose[1] - height_border)
if mouth_left is not None:
mouth_left = (mouth_left[0] - width_border, mouth_left[1] - height_border)
if mouth_right is not None:
mouth_right = (mouth_right[0] - width_border, mouth_right[1] - height_border)
return DetectedFace( return DetectedFace(
img=detected_face, img=detected_face,
facial_area=FacialAreaRegion( facial_area=FacialAreaRegion(
x=x, y=y, h=h, w=w, confidence=confidence, left_eye=left_eye, right_eye=right_eye x=x,
y=y,
h=h,
w=w,
confidence=confidence,
left_eye=left_eye,
right_eye=right_eye,
nose=nose,
mouth_left=mouth_left,
mouth_right=mouth_right,
), ),
confidence=confidence, confidence=confidence,
) )

View File

@ -78,18 +78,18 @@ def find(
Returns: Returns:
results (List[pd.DataFrame] or List[List[Dict[str, Any]]]): results (List[pd.DataFrame] or List[List[Dict[str, Any]]]):
A list of pandas dataframes (if `batched=False`) or A list of pandas dataframes (if `batched=False`) or
a list of dicts (if `batched=True`). a list of dicts (if `batched=True`).
Each dataframe or dict corresponds to the identity information for Each dataframe or dict corresponds to the identity information for
an individual detected in the source image. an individual detected in the source image.
Note: If you have a large database and/or a source photo with many faces, Note: If you have a large database and/or a source photo with many faces,
use `batched=True`, as it is optimized for large batch processing. use `batched=True`, as it is optimized for large batch processing.
Please pay attention that when using `batched=True`, the function returns Please pay attention that when using `batched=True`, the function returns
a list of dicts (not a list of DataFrames), a list of dicts (not a list of DataFrames),
but with the same keys as the columns in the DataFrame. but with the same keys as the columns in the DataFrame.
The DataFrame columns or dict keys include: The DataFrame columns or dict keys include:
- 'identity': Identity label of the detected individual. - 'identity': Identity label of the detected individual.
@ -266,7 +266,7 @@ def find(
align, align,
threshold, threshold,
normalization, normalization,
anti_spoofing anti_spoofing,
) )
df = pd.DataFrame(representations) df = pd.DataFrame(representations)
@ -441,6 +441,7 @@ def __find_bulk_embeddings(
return representations return representations
def find_batched( def find_batched(
representations: List[Dict[str, Any]], representations: List[Dict[str, Any]],
source_objs: List[Dict[str, Any]], source_objs: List[Dict[str, Any]],
@ -459,11 +460,11 @@ def find_batched(
The function uses batch processing for efficient computation of distances. The function uses batch processing for efficient computation of distances.
Args: Args:
representations (List[Dict[str, Any]]): representations (List[Dict[str, Any]]):
A list of dictionaries containing precomputed target embeddings and associated metadata. A list of dictionaries containing precomputed target embeddings and associated metadata.
Each dictionary should have at least the key `embedding`. Each dictionary should have at least the key `embedding`.
source_objs (List[Dict[str, Any]]): source_objs (List[Dict[str, Any]]):
A list of dictionaries representing the source images to compare against A list of dictionaries representing the source images to compare against
the target embeddings. Each dictionary should contain: the target embeddings. Each dictionary should contain:
- `face`: The image data or path to the source face image. - `face`: The image data or path to the source face image.
@ -471,7 +472,7 @@ def find_batched(
indicating the facial region. indicating the facial region.
- Optionally, `is_real`: A boolean indicating if the face is real - Optionally, `is_real`: A boolean indicating if the face is real
(used for anti-spoofing). (used for anti-spoofing).
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face).
@ -499,7 +500,7 @@ def find_batched(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False). anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns: Returns:
List[List[Dict[str, Any]]]: List[List[Dict[str, Any]]]:
A list where each element corresponds to a source face and A list where each element corresponds to a source face and
contains a list of dictionaries with matching faces. contains a list of dictionaries with matching faces.
""" """
@ -508,27 +509,24 @@ def find_batched(
metadata = set() metadata = set()
for item in representations: for item in representations:
emb = item.get('embedding') emb = item.get("embedding")
if emb is not None: if emb is not None:
embeddings_list.append(emb) embeddings_list.append(emb)
valid_mask.append(True) valid_mask.append(True)
else: else:
embeddings_list.append(np.zeros_like(representations[0]['embedding'])) embeddings_list.append(np.zeros_like(representations[0]["embedding"]))
valid_mask.append(False) valid_mask.append(False)
metadata.update(item.keys()) metadata.update(item.keys())
# remove embedding key from other keys # remove embedding key from other keys
metadata.discard('embedding') metadata.discard("embedding")
metadata = list(metadata) metadata = list(metadata)
embeddings = np.array(embeddings_list) # (N, D) embeddings = np.array(embeddings_list) # (N, D)
valid_mask = np.array(valid_mask) # (N,) valid_mask = np.array(valid_mask) # (N,)
data = { data = {key: np.array([item.get(key, None) for item in representations]) for key in metadata}
key: np.array([item.get(key, None) for item in representations])
for key in metadata
}
target_embeddings = [] target_embeddings = []
source_regions = [] source_regions = []
@ -558,101 +556,46 @@ def find_batched(
target_threshold = threshold or verification.find_threshold(model_name, distance_metric) target_threshold = threshold or verification.find_threshold(model_name, distance_metric)
target_thresholds.append(target_threshold) target_thresholds.append(target_threshold)
target_embeddings = np.array(target_embeddings) # (M, D) target_embeddings = np.array(target_embeddings) # (M, D)
target_thresholds = np.array(target_thresholds) # (M,) target_thresholds = np.array(target_thresholds) # (M,)
source_regions_arr = { source_regions_arr = {
'source_x': np.array([region['x'] for region in source_regions]), "source_x": np.array([region["x"] for region in source_regions]),
'source_y': np.array([region['y'] for region in source_regions]), "source_y": np.array([region["y"] for region in source_regions]),
'source_w': np.array([region['w'] for region in source_regions]), "source_w": np.array([region["w"] for region in source_regions]),
'source_h': np.array([region['h'] for region in source_regions]), "source_h": np.array([region["h"] for region in source_regions]),
} }
def find_cosine_distance_batch( distances = verification.find_distance(embeddings, target_embeddings, distance_metric) # (M, N)
embeddings: np.ndarray, target_embeddings: np.ndarray
) -> np.ndarray:
"""
Find the cosine distances between batches of embeddings
Args:
embeddings (np.ndarray): array of shape (N, D)
target_embeddings (np.ndarray): array of shape (M, D)
Returns:
np.ndarray: distance matrix of shape (M, N)
"""
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
cosine_similarities = np.dot(target_embeddings_norm, embeddings_norm.T)
cosine_distances = 1 - cosine_similarities
return cosine_distances
def find_euclidean_distance_batch(
embeddings: np.ndarray, target_embeddings: np.ndarray
) -> np.ndarray:
"""
Find the Euclidean distances between batches of embeddings
Args:
embeddings (np.ndarray): array of shape (N, D)
target_embeddings (np.ndarray): array of shape (M, D)
Returns:
np.ndarray: distance matrix of shape (M, N)
"""
diff = embeddings[None, :, :] - target_embeddings[:, None, :] # (M, N, D)
distances = np.linalg.norm(diff, axis=2) # (M, N)
return distances
def find_distance_batch(
embeddings: np.ndarray, target_embeddings: np.ndarray, distance_metric: str,
) -> np.ndarray:
"""
Find pairwise distances between batches of embeddings using the specified distance metric
Args:
embeddings (np.ndarray): array of shape (N, D)
target_embeddings (np.ndarray): array of shape (M, D)
distance_metric (str): distance metric ('cosine', 'euclidean', 'euclidean_l2')
Returns:
np.ndarray: distance matrix of shape (M, N)
"""
if distance_metric == "cosine":
distances = find_cosine_distance_batch(embeddings, target_embeddings)
elif distance_metric == "euclidean":
distances = find_euclidean_distance_batch(embeddings, target_embeddings)
elif distance_metric == "euclidean_l2":
embeddings_norm = verification.l2_normalize(embeddings, axis=1)
target_embeddings_norm = verification.l2_normalize(target_embeddings, axis=1)
distances = find_euclidean_distance_batch(embeddings_norm, target_embeddings_norm)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
return np.round(distances, 6)
distances = find_distance_batch(embeddings, target_embeddings, distance_metric) # (M, N)
distances[:, ~valid_mask] = np.inf distances[:, ~valid_mask] = np.inf
resp_obj = [] resp_obj = []
for i in range(len(target_embeddings)): for i in range(len(target_embeddings)):
target_distances = distances[i] # (N,) target_distances = distances[i] # (N,)
target_threshold = target_thresholds[i] target_threshold = target_thresholds[i]
N = embeddings.shape[0] N = embeddings.shape[0]
result_data = dict(data) result_data = dict(data)
result_data.update({ result_data.update(
'source_x': np.full(N, source_regions_arr['source_x'][i]), {
'source_y': np.full(N, source_regions_arr['source_y'][i]), "source_x": np.full(N, source_regions_arr["source_x"][i]),
'source_w': np.full(N, source_regions_arr['source_w'][i]), "source_y": np.full(N, source_regions_arr["source_y"][i]),
'source_h': np.full(N, source_regions_arr['source_h'][i]), "source_w": np.full(N, source_regions_arr["source_w"][i]),
'threshold': np.full(N, target_threshold), "source_h": np.full(N, source_regions_arr["source_h"][i]),
'distance': target_distances, "threshold": np.full(N, target_threshold),
}) "distance": target_distances,
}
)
mask = target_distances <= target_threshold mask = target_distances <= target_threshold
filtered_data = {key: value[mask] for key, value in result_data.items()} filtered_data = {key: value[mask] for key, value in result_data.items()}
sorted_indices = np.argsort(filtered_data['distance']) sorted_indices = np.argsort(filtered_data["distance"])
sorted_data = {key: value[sorted_indices] for key, value in filtered_data.items()} sorted_data = {key: value[sorted_indices] for key, value in filtered_data.items()}
num_results = len(sorted_data['distance']) num_results = len(sorted_data["distance"])
result_dicts = [ result_dicts = [
{key: sorted_data[key][i] for key in sorted_data} {key: sorted_data[key][i] for key in sorted_data} for i in range(num_results)
for i in range(num_results)
] ]
resp_obj.append(result_dicts) resp_obj.append(result_dicts)
return resp_obj return resp_obj

View File

@ -263,45 +263,73 @@ def __extract_faces_and_embeddings(
def find_cosine_distance( def find_cosine_distance(
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list] source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
) -> np.float64: ) -> Union[np.float64, np.ndarray]:
""" """
Find cosine distance between two given vectors Find cosine distance between two given vectors or batches of vectors.
Args: Args:
source_representation (np.ndarray or list): 1st vector source_representation (np.ndarray or list): 1st vector or batch of vectors.
test_representation (np.ndarray or list): 2nd vector test_representation (np.ndarray or list): 2nd vector or batch of vectors.
Returns Returns
distance (np.float64): calculated cosine distance np.float64 or np.ndarray: Calculated cosine distance(s).
It returns a np.float64 for single embeddings and np.ndarray for batch embeddings.
""" """
if isinstance(source_representation, list): # Convert inputs to numpy arrays if necessary
source_representation = np.array(source_representation) source_representation = np.asarray(source_representation)
test_representation = np.asarray(test_representation)
if isinstance(test_representation, list): if source_representation.ndim == 1 and test_representation.ndim == 1:
test_representation = np.array(test_representation) # single embedding
dot_product = np.dot(source_representation, test_representation)
a = np.dot(source_representation, test_representation) source_norm = np.linalg.norm(source_representation)
b = np.linalg.norm(source_representation) test_norm = np.linalg.norm(test_representation)
c = np.linalg.norm(test_representation) distances = 1 - dot_product / (source_norm * test_norm)
return 1 - a / (b * c) elif source_representation.ndim == 2 and test_representation.ndim == 2:
# list of embeddings (batch)
source_normed = l2_normalize(source_representation, axis=1) # (N, D)
test_normed = l2_normalize(test_representation, axis=1) # (M, D)
cosine_similarities = np.dot(test_normed, source_normed.T) # (M, N)
distances = 1 - cosine_similarities
else:
raise ValueError(
f"Embeddings must be 1D or 2D, but received "
f"source shape: {source_representation.shape}, test shape: {test_representation.shape}"
)
return distances
def find_euclidean_distance( def find_euclidean_distance(
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list] source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
) -> np.float64: ) -> Union[np.float64, np.ndarray]:
""" """
Find euclidean distance between two given vectors Find Euclidean distance between two vectors or batches of vectors.
Args: Args:
source_representation (np.ndarray or list): 1st vector source_representation (np.ndarray or list): 1st vector or batch of vectors.
test_representation (np.ndarray or list): 2nd vector test_representation (np.ndarray or list): 2nd vector or batch of vectors.
Returns
distance (np.float64): calculated euclidean distance Returns:
np.float64 or np.ndarray: Euclidean distance(s).
Returns a np.float64 for single embeddings and np.ndarray for batch embeddings.
""" """
if isinstance(source_representation, list): # Convert inputs to numpy arrays if necessary
source_representation = np.array(source_representation) source_representation = np.asarray(source_representation)
test_representation = np.asarray(test_representation)
if isinstance(test_representation, list): # Single embedding case (1D arrays)
test_representation = np.array(test_representation) if source_representation.ndim == 1 and test_representation.ndim == 1:
distances = np.linalg.norm(source_representation - test_representation)
return np.linalg.norm(source_representation - test_representation) # Batch embeddings case (2D arrays)
elif source_representation.ndim == 2 and test_representation.ndim == 2:
diff = (
source_representation[None, :, :] - test_representation[:, None, :]
) # (N, D) - (M, D) = (M, N, D)
distances = np.linalg.norm(diff, axis=2) # (M, N)
else:
raise ValueError(
f"Embeddings must be 1D or 2D, but received "
f"source shape: {source_representation.shape}, test shape: {test_representation.shape}"
)
return distances
def l2_normalize( def l2_normalize(
@ -315,8 +343,8 @@ def l2_normalize(
Returns: Returns:
np.ndarray: l2 normalized vector np.ndarray: l2 normalized vector
""" """
if isinstance(x, list): # Convert inputs to numpy arrays if necessary
x = np.array(x) x = np.asarray(x)
norm = np.linalg.norm(x, axis=axis, keepdims=True) norm = np.linalg.norm(x, axis=axis, keepdims=True)
return x / (norm + epsilon) return x / (norm + epsilon)
@ -325,23 +353,39 @@ def find_distance(
alpha_embedding: Union[np.ndarray, list], alpha_embedding: Union[np.ndarray, list],
beta_embedding: Union[np.ndarray, list], beta_embedding: Union[np.ndarray, list],
distance_metric: str, distance_metric: str,
) -> np.float64: ) -> Union[np.float64, np.ndarray]:
""" """
Wrapper to find distance between vectors according to the given distance metric Wrapper to find the distance between vectors based on the specified distance metric.
Args: Args:
source_representation (np.ndarray or list): 1st vector alpha_embedding (np.ndarray or list): 1st vector or batch of vectors.
test_representation (np.ndarray or list): 2nd vector beta_embedding (np.ndarray or list): 2nd vector or batch of vectors.
Returns distance_metric (str): The type of distance to compute
distance (np.float64): calculated cosine distance ('cosine', 'euclidean', or 'euclidean_l2').
Returns:
np.float64 or np.ndarray: The calculated distance(s).
""" """
# Convert inputs to numpy arrays if necessary
alpha_embedding = np.asarray(alpha_embedding)
beta_embedding = np.asarray(beta_embedding)
# Ensure that both embeddings are either 1D or 2D
if alpha_embedding.ndim != beta_embedding.ndim or alpha_embedding.ndim not in (1, 2):
raise ValueError(
f"Both embeddings must be either 1D or 2D, but received "
f"alpha shape: {alpha_embedding.shape}, beta shape: {beta_embedding.shape}"
)
if distance_metric == "cosine": if distance_metric == "cosine":
distance = find_cosine_distance(alpha_embedding, beta_embedding) distance = find_cosine_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean": elif distance_metric == "euclidean":
distance = find_euclidean_distance(alpha_embedding, beta_embedding) distance = find_euclidean_distance(alpha_embedding, beta_embedding)
elif distance_metric == "euclidean_l2": elif distance_metric == "euclidean_l2":
distance = find_euclidean_distance( axis = None if alpha_embedding.ndim == 1 else 1
l2_normalize(alpha_embedding), l2_normalize(beta_embedding) normalized_alpha = l2_normalize(alpha_embedding, axis=axis)
) normalized_beta = l2_normalize(beta_embedding, axis=axis)
distance = find_euclidean_distance(normalized_alpha, normalized_beta)
else: else:
raise ValueError("Invalid distance_metric passed - ", distance_metric) raise ValueError("Invalid distance_metric passed - ", distance_metric)
return np.round(distance, 6) return np.round(distance, 6)