interface functions moved to modules

Sefik Ilkin Serengil 2024-01-13 22:21:41 +00:00
parent 1b40870a8f
commit 7f719b87bd
11 changed files with 957 additions and 531 deletions

.github/pull_request_template.md (new file)

@ -0,0 +1,13 @@
## Tickets
https://github.com/serengil/deepface/issues/XXX
### What has been done
With this PR, ...
## How to test
```shell
make lint && make test
```

deepface/DeepFace.py
@ -1,35 +1,27 @@
# common dependencies
import os
from os import path
import warnings
import time
import pickle
import logging
from typing import Any, Dict, List, Tuple, Union
# 3rd party dependencies
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
import tensorflow as tf
from deprecated import deprecated
# package dependencies
from deepface.basemodels import (
VGGFace,
OpenFace,
Facenet,
Facenet512,
FbDeepFace,
DeepID,
DlibWrapper,
ArcFace,
SFace,
)
from deepface.extendedmodels import Age, Gender, Race, Emotion
from deepface.commons import functions, realtime, distance as dst
from deepface.commons import functions
from deepface.commons.logger import Logger
from deepface.modules import (
modeling,
representation,
verification,
recognition,
demography,
detection,
realtime,
)
# pylint: disable=no-else-raise, simplifiable-if-expression
@ -60,38 +52,7 @@ def build_model(model_name: str) -> Union[Model, Any]:
Returns:
built deepface model ( (tf.)keras.models.Model )
"""
# singleton design pattern
global model_obj
models = {
"VGG-Face": VGGFace.loadModel,
"OpenFace": OpenFace.loadModel,
"Facenet": Facenet.loadModel,
"Facenet512": Facenet512.loadModel,
"DeepFace": FbDeepFace.loadModel,
"DeepID": DeepID.loadModel,
"Dlib": DlibWrapper.loadModel,
"ArcFace": ArcFace.loadModel,
"SFace": SFace.load_model,
"Emotion": Emotion.loadModel,
"Age": Age.loadModel,
"Gender": Gender.loadModel,
"Race": Race.loadModel,
}
if not "model_obj" in globals():
model_obj = {}
if not model_name in model_obj:
model = models.get(model_name)
if model:
model = model()
model_obj[model_name] = model
else:
raise ValueError(f"Invalid model_name passed - {model_name}")
return model_obj[model_name]
return modeling.build_model(model_name=model_name)
def verify(
@ -149,90 +110,17 @@ def verify(
"""
tic = time.time()
# --------------------------------
target_size = functions.find_target_size(model_name=model_name)
# img pairs might have many faces
img1_objs = functions.extract_faces(
img=img1_path,
target_size=target_size,
return verification.verify(
img1_path=img1_path,
img2_path=img2_path,
model_name=model_name,
detector_backend=detector_backend,
grayscale=False,
distance_metric=distance_metric,
enforce_detection=enforce_detection,
align=align,
normalization=normalization,
)
img2_objs = functions.extract_faces(
img=img2_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
# --------------------------------
distances = []
regions = []
# now we will find the face pair with minimum distance
for img1_content, img1_region, _ in img1_objs:
for img2_content, img2_region, _ in img2_objs:
img1_embedding_obj = represent(
img_path=img1_content,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
img2_embedding_obj = represent(
img_path=img2_content,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
img1_representation = img1_embedding_obj[0]["embedding"]
img2_representation = img2_embedding_obj[0]["embedding"]
if distance_metric == "cosine":
distance = dst.findCosineDistance(img1_representation, img2_representation)
elif distance_metric == "euclidean":
distance = dst.findEuclideanDistance(img1_representation, img2_representation)
elif distance_metric == "euclidean_l2":
distance = dst.findEuclideanDistance(
dst.l2_normalize(img1_representation), dst.l2_normalize(img2_representation)
)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
distances.append(distance)
regions.append((img1_region, img2_region))
# -------------------------------
threshold = dst.findThreshold(model_name, distance_metric)
distance = min(distances) # best distance
facial_areas = regions[np.argmin(distances)]
toc = time.time()
resp_obj = {
"verified": True if distance <= threshold else False,
"distance": distance,
"threshold": threshold,
"model": model_name,
"detector_backend": detector_backend,
"similarity_metric": distance_metric,
"facial_areas": {"img1": facial_areas[0], "img2": facial_areas[1]},
"time": round(toc - tic, 2),
}
return resp_obj
def analyze(
img_path: Union[str, np.ndarray],
@ -301,116 +189,15 @@ def analyze(
}
]
"""
# ---------------------------------
# validate actions
if isinstance(actions, str):
actions = (actions,)
# check if actions is not an iterable or empty.
if not hasattr(actions, "__getitem__") or not actions:
raise ValueError("`actions` must be a list of strings.")
actions = list(actions)
# For each action, check if it is valid
for action in actions:
if action not in ("emotion", "age", "gender", "race"):
raise ValueError(
f"Invalid action passed ({repr(action)})). "
"Valid actions are `emotion`, `age`, `gender`, `race`."
)
# ---------------------------------
# build models
models = {}
if "emotion" in actions:
models["emotion"] = build_model("Emotion")
if "age" in actions:
models["age"] = build_model("Age")
if "gender" in actions:
models["gender"] = build_model("Gender")
if "race" in actions:
models["race"] = build_model("Race")
# ---------------------------------
resp_objects = []
img_objs = functions.extract_faces(
img=img_path,
target_size=(224, 224),
detector_backend=detector_backend,
grayscale=False,
return demography.analyze(
img_path=img_path,
actions=actions,
enforce_detection=enforce_detection,
detector_backend=detector_backend,
align=align,
silent=silent,
)
for img_content, img_region, img_confidence in img_objs:
if img_content.shape[0] > 0 and img_content.shape[1] > 0:
obj = {}
# facial attribute analysis
pbar = tqdm(
range(0, len(actions)),
desc="Finding actions",
disable=silent if len(actions) > 1 else True,
)
for index in pbar:
action = actions[index]
pbar.set_description(f"Action: {action}")
if action == "emotion":
img_gray = cv2.cvtColor(img_content[0], cv2.COLOR_BGR2GRAY)
img_gray = cv2.resize(img_gray, (48, 48))
img_gray = np.expand_dims(img_gray, axis=0)
emotion_predictions = models["emotion"].predict(img_gray, verbose=0)[0, :]
sum_of_predictions = emotion_predictions.sum()
obj["emotion"] = {}
for i, emotion_label in enumerate(Emotion.labels):
emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions
obj["emotion"][emotion_label] = emotion_prediction
obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)]
elif action == "age":
age_predictions = models["age"].predict(img_content, verbose=0)[0, :]
apparent_age = Age.findApparentAge(age_predictions)
# int cast is for exception - object of type 'float32' is not JSON serializable
obj["age"] = int(apparent_age)
elif action == "gender":
gender_predictions = models["gender"].predict(img_content, verbose=0)[0, :]
obj["gender"] = {}
for i, gender_label in enumerate(Gender.labels):
gender_prediction = 100 * gender_predictions[i]
obj["gender"][gender_label] = gender_prediction
obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)]
elif action == "race":
race_predictions = models["race"].predict(img_content, verbose=0)[0, :]
sum_of_predictions = race_predictions.sum()
obj["race"] = {}
for i, race_label in enumerate(Race.labels):
race_prediction = 100 * race_predictions[i] / sum_of_predictions
obj["race"][race_label] = race_prediction
obj["dominant_race"] = Race.labels[np.argmax(race_predictions)]
# -----------------------------
# mention facial areas
obj["region"] = img_region
# include image confidence
obj["face_confidence"] = img_confidence
resp_objects.append(obj)
return resp_objects
def find(
img_path: Union[str, np.ndarray],
@ -457,211 +244,18 @@ def find(
This function returns a list of pandas DataFrames. Each item in the list corresponds to
an identity in img_path.
"""
tic = time.time()
# -------------------------------
if os.path.isdir(db_path) is not True:
raise ValueError("Passed db_path does not exist!")
target_size = functions.find_target_size(model_name=model_name)
# ---------------------------------------
file_name = f"representations_{model_name}.pkl"
file_name = file_name.replace("-", "_").lower()
df_cols = [
"identity",
f"{model_name}_representation",
"target_x",
"target_y",
"target_w",
"target_h",
]
if path.exists(db_path + "/" + file_name):
if not silent:
logger.warn(
f"Representations for images in {db_path} folder were previously stored"
f" in {file_name}. If you added new instances after the creation, then please "
"delete this file and call find function again. It will create it again."
)
with open(f"{db_path}/{file_name}", "rb") as f:
representations = pickle.load(f)
if len(representations) > 0 and len(representations[0]) != len(df_cols):
raise ValueError(
f"Seems existing {db_path}/{file_name} is out-of-the-date."
"Delete it and re-run."
)
if not silent:
logger.info(f"There are {len(representations)} representations found in {file_name}")
else: # create representation.pkl from scratch
employees = []
for r, _, f in os.walk(db_path):
for file in f:
if (
(".jpg" in file.lower())
or (".jpeg" in file.lower())
or (".png" in file.lower())
):
exact_path = r + "/" + file
employees.append(exact_path)
if len(employees) == 0:
raise ValueError(
"There is no image in ",
db_path,
" folder! Validate .jpg or .png files exist in this path.",
)
# ------------------------
# find representations for db images
representations = []
# for employee in employees:
pbar = tqdm(
range(0, len(employees)),
desc="Finding representations",
disable=silent,
)
for index in pbar:
employee = employees[index]
img_objs = functions.extract_faces(
img=employee,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
for img_content, img_region, _ in img_objs:
embedding_obj = represent(
img_path=img_content,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
img_representation = embedding_obj[0]["embedding"]
instance = []
instance.append(employee)
instance.append(img_representation)
instance.append(img_region["x"])
instance.append(img_region["y"])
instance.append(img_region["w"])
instance.append(img_region["h"])
representations.append(instance)
# -------------------------------
with open(f"{db_path}/{file_name}", "wb") as f:
pickle.dump(representations, f)
if not silent:
logger.info(
f"Representations stored in {db_path}/{file_name} file."
+ "Please delete this file when you add new identities in your database."
)
# ----------------------------
# now, we got representations for facial database
df = pd.DataFrame(
representations,
columns=df_cols,
)
# img path might have more than once face
source_objs = functions.extract_faces(
img=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
return recognition.find(
img_path=img_path,
db_path=db_path,
model_name=model_name,
distance_metric=distance_metric,
enforce_detection=enforce_detection,
detector_backend=detector_backend,
align=align,
normalization=normalization,
silent=silent,
)
resp_obj = []
for source_img, source_region, _ in source_objs:
target_embedding_obj = represent(
img_path=source_img,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
target_representation = target_embedding_obj[0]["embedding"]
result_df = df.copy() # df will be filtered in each img
result_df["source_x"] = source_region["x"]
result_df["source_y"] = source_region["y"]
result_df["source_w"] = source_region["w"]
result_df["source_h"] = source_region["h"]
distances = []
for index, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"]
target_dims = len(list(target_representation))
source_dims = len(list(source_representation))
if target_dims != source_dims:
raise ValueError(
"Source and target embeddings must have same dimensions but "
+ f"{target_dims}:{source_dims}. Model structure may change"
+ " after pickle created. Delete the {file_name} and re-run."
)
if distance_metric == "cosine":
distance = dst.findCosineDistance(source_representation, target_representation)
elif distance_metric == "euclidean":
distance = dst.findEuclideanDistance(source_representation, target_representation)
elif distance_metric == "euclidean_l2":
distance = dst.findEuclideanDistance(
dst.l2_normalize(source_representation),
dst.l2_normalize(target_representation),
)
else:
raise ValueError(f"invalid distance metric passes - {distance_metric}")
distances.append(distance)
# ---------------------------
result_df[f"{model_name}_{distance_metric}"] = distances
threshold = dst.findThreshold(model_name, distance_metric)
result_df = result_df.drop(columns=[f"{model_name}_representation"])
# pylint: disable=unsubscriptable-object
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
result_df = result_df.sort_values(
by=[f"{model_name}_{distance_metric}"], ascending=True
).reset_index(drop=True)
resp_obj.append(result_df)
# -----------------------------------
toc = time.time()
if not silent:
logger.info(f"find function lasts {toc - tic} seconds")
return resp_obj
def represent(
img_path: Union[str, np.ndarray],
@ -714,64 +308,14 @@ def represent(
"face_confidence": float
}
"""
resp_objs = []
model = build_model(model_name)
# ---------------------------------
# we have run pre-process in verification. so, this can be skipped if it is coming from verify.
target_size = functions.find_target_size(model_name=model_name)
if detector_backend != "skip":
img_objs = functions.extract_faces(
img=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
else: # skip
# Try load. If load error, will raise exception internal
img, _ = functions.load_image(img_path)
# --------------------------------
if len(img.shape) == 4:
img = img[0] # e.g. (1, 224, 224, 3) to (224, 224, 3)
if len(img.shape) == 3:
img = cv2.resize(img, target_size)
img = np.expand_dims(img, axis=0)
# when called from verify, this is already normalized. But needed when user given.
if img.max() > 1:
img = img.astype(np.float32) / 255.0
# --------------------------------
# make dummy region and confidence to keep compatibility with `extract_faces`
img_region = {"x": 0, "y": 0, "w": img.shape[1], "h": img.shape[2]}
img_objs = [(img, img_region, 0)]
# ---------------------------------
for img, region, confidence in img_objs:
# custom normalization
img = functions.normalize_input(img=img, normalization=normalization)
# represent
# if "keras" in str(type(model)):
if isinstance(model, Model):
# model.predict causes memory issue when it is called in a for loop
# embedding = model.predict(img, verbose=0)[0].tolist()
embedding = model(img, training=False).numpy()[0].tolist()
# if you still get verbose logging. try call
# - `tf.keras.utils.disable_interactive_logging()`
# in your main program
else:
# SFace and Dlib are not keras models and no verbose arguments
embedding = model.predict(img)[0].tolist()
resp_obj = {}
resp_obj["embedding"] = embedding
resp_obj["facial_area"] = region
resp_obj["face_confidence"] = confidence
resp_objs.append(resp_obj)
return resp_objs
return representation.represent(
img_path=img_path,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend=detector_backend,
align=align,
normalization=normalization,
)
def stream(
@ -807,23 +351,15 @@ def stream(
"""
if time_threshold < 1:
raise ValueError(
"time_threshold must be greater than the value 1 but you passed " + str(time_threshold)
)
if frame_threshold < 1:
raise ValueError(
"frame_threshold must be greater than the value 1 but you passed "
+ str(frame_threshold)
)
time_threshold = max(time_threshold, 1)
frame_threshold = max(frame_threshold, 1)
realtime.analysis(
db_path,
model_name,
detector_backend,
distance_metric,
enable_face_analysis,
db_path=db_path,
model_name=model_name,
detector_backend=detector_backend,
distance_metric=distance_metric,
enable_face_analysis=enable_face_analysis,
source=source,
time_threshold=time_threshold,
frame_threshold=frame_threshold,
@ -867,31 +403,15 @@ def extract_faces(
"""
resp_objs = []
img_objs = functions.extract_faces(
img=img_path,
return detection.extract_faces(
img_path=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=grayscale,
enforce_detection=enforce_detection,
align=align,
grayscale=grayscale,
)
for img, region, confidence in img_objs:
resp_obj = {}
# discard expanded dimension
if len(img.shape) == 4:
img = img[0]
resp_obj["face"] = img[:, :, ::-1]
resp_obj["facial_area"] = region
resp_obj["confidence"] = confidence
resp_objs.append(resp_obj)
return resp_objs
# ---------------------------
# deprecated functions
@ -904,7 +424,7 @@ def detectFace(
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
) -> np.ndarray:
) -> Union[np.ndarray, None]:
"""
Deprecated function. Use extract_faces for the same functionality.
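The public interface is unchanged by this refactor: DeepFace.py keeps its signatures and simply forwards to the new modules. A minimal caller-side sketch (the image paths are hypothetical placeholders):

```python
from deepface import DeepFace

# same facade API as before; internally this now delegates to
# deepface.modules.verification.verify
result = DeepFace.verify(
    img1_path="dataset/img1.jpg",  # placeholder paths
    img2_path="dataset/img2.jpg",
    model_name="VGG-Face",
    distance_metric="cosine",
)
print(result["verified"], result["distance"], result["threshold"])
```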

deepface/modules/demography.py (new file)
@ -0,0 +1,184 @@
# built-in dependencies
from typing import Any, Dict, List, Union
# 3rd party dependencies
import numpy as np
from tqdm import tqdm
import cv2
# project dependencies
from deepface.modules import modeling
from deepface.commons import functions
from deepface.extendedmodels import Age, Gender, Race, Emotion
def analyze(
img_path: Union[str, np.ndarray],
actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
enforce_detection: bool = True,
detector_backend: str = "opencv",
align: bool = True,
silent: bool = False,
) -> List[Dict[str, Any]]:
"""
This function analyzes facial attributes including age, gender, emotion and race.
In the background, analysis function builds convolutional neural network models to
classify age, gender, emotion and race of the input image.
Parameters:
img_path: exact image path, numpy array (BGR) or base64 encoded image.
If the source image has more than one face, the result will contain one entry
per face appearing in the image.
actions (tuple): The default is ('age', 'gender', 'emotion', 'race'). You can drop
some of those attributes.
enforce_detection (bool): By default, the function throws an exception if no face is
detected. Set this to False if you do not want an exception. This might be convenient
for low resolution images.
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
dlib, mediapipe or yolov8.
align (boolean): alignment according to the eye positions.
silent (boolean): disable (some) log messages
Returns:
The function returns a list of dictionaries for each face appearing in the image.
[
{
"region": {'x': 230, 'y': 120, 'w': 36, 'h': 45},
"age": 28.66,
'face_confidence': 0.9993908405303955,
"dominant_gender": "Woman",
"gender": {
'Woman': 99.99407529830933,
'Man': 0.005928758764639497,
},
"dominant_emotion": "neutral",
"emotion": {
'sad': 37.65260875225067,
'angry': 0.15512987738475204,
'surprise': 0.0022171278033056296,
'fear': 1.2489334680140018,
'happy': 4.609785228967667,
'disgust': 9.698561953541684e-07,
'neutral': 56.33133053779602
},
"dominant_race": "white",
"race": {
'indian': 0.5480832420289516,
'asian': 0.7830780930817127,
'latino hispanic': 2.0677512511610985,
'black': 0.06337375962175429,
'middle eastern': 3.088453598320484,
'white': 93.44925880432129
}
}
]
"""
# ---------------------------------
# validate actions
if isinstance(actions, str):
actions = (actions,)
# check if actions is not an iterable or empty.
if not hasattr(actions, "__getitem__") or not actions:
raise ValueError("`actions` must be a list of strings.")
actions = list(actions)
# For each action, check if it is valid
for action in actions:
if action not in ("emotion", "age", "gender", "race"):
raise ValueError(
f"Invalid action passed ({repr(action)})). "
"Valid actions are `emotion`, `age`, `gender`, `race`."
)
# ---------------------------------
resp_objects = []
img_objs = functions.extract_faces(
img=img_path,
target_size=(224, 224),
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
for img_content, img_region, img_confidence in img_objs:
if img_content.shape[0] > 0 and img_content.shape[1] > 0:
obj = {}
# facial attribute analysis
pbar = tqdm(
range(0, len(actions)),
desc="Finding actions",
disable=silent if len(actions) > 1 else True,
)
for index in pbar:
action = actions[index]
pbar.set_description(f"Action: {action}")
if action == "emotion":
img_gray = cv2.cvtColor(img_content[0], cv2.COLOR_BGR2GRAY)
img_gray = cv2.resize(img_gray, (48, 48))
img_gray = np.expand_dims(img_gray, axis=0)
emotion_predictions = modeling.build_model("Emotion").predict(
img_gray, verbose=0
)[0, :]
sum_of_predictions = emotion_predictions.sum()
obj["emotion"] = {}
for i, emotion_label in enumerate(Emotion.labels):
emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions
obj["emotion"][emotion_label] = emotion_prediction
obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)]
elif action == "age":
age_predictions = modeling.build_model("Age").predict(img_content, verbose=0)[
0, :
]
apparent_age = Age.findApparentAge(age_predictions)
# int cast is for exception - object of type 'float32' is not JSON serializable
obj["age"] = int(apparent_age)
elif action == "gender":
gender_predictions = modeling.build_model("Gender").predict(
img_content, verbose=0
)[0, :]
obj["gender"] = {}
for i, gender_label in enumerate(Gender.labels):
gender_prediction = 100 * gender_predictions[i]
obj["gender"][gender_label] = gender_prediction
obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)]
elif action == "race":
race_predictions = modeling.build_model("Race").predict(img_content, verbose=0)[
0, :
]
sum_of_predictions = race_predictions.sum()
obj["race"] = {}
for i, race_label in enumerate(Race.labels):
race_prediction = 100 * race_predictions[i] / sum_of_predictions
obj["race"][race_label] = race_prediction
obj["dominant_race"] = Race.labels[np.argmax(race_predictions)]
# -----------------------------
# mention facial areas
obj["region"] = img_region
# include image confidence
obj["face_confidence"] = img_confidence
resp_objects.append(obj)
return resp_objects
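A short usage sketch for the relocated analyze function, assuming an input image with at least one detectable face (the path is a placeholder):

```python
from deepface import DeepFace

results = DeepFace.analyze(
    img_path="dataset/img5.jpg",  # placeholder image
    actions=("age", "gender", "emotion", "race"),
    detector_backend="opencv",
)
# one dictionary per detected face
for face in results:
    print(face["region"], face["age"], face["dominant_gender"],
          face["dominant_emotion"], face["dominant_race"])
```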

deepface/modules/detection.py (new file)
@ -0,0 +1,72 @@
# built-in dependencies
from typing import Any, Dict, List, Tuple, Union
# 3rd party dependencies
import numpy as np
# project dependencies
from deepface.commons import functions
def extract_faces(
img_path: Union[str, np.ndarray],
target_size: Tuple[int, int] = (224, 224),
detector_backend: str = "opencv",
enforce_detection: bool = True,
align: bool = True,
grayscale: bool = False,
) -> List[Dict[str, Any]]:
"""
This function applies pre-processing stages of a face recognition pipeline
including detection and alignment
Parameters:
img_path: exact image path, numpy array (BGR) or base64 encoded image.
The source image can have many faces; the result will then contain one entry
per face appearing in that source image.
target_size (tuple): final shape of facial image. black pixels will be
added to resize the image.
detector_backend (string): face detection backends are retinaface, mtcnn,
opencv, ssd or dlib
enforce_detection (boolean): the function throws an exception if a face cannot be
detected in the given image. Set this to False if you want the function to run
anyway instead of raising an exception.
align (boolean): alignment according to the eye positions.
grayscale (boolean): extract faces in RGB or grayscale
Returns:
list of dictionaries. Each dictionary will have facial image itself (RGB),
extracted area from the original image and confidence score.
"""
resp_objs = []
img_objs = functions.extract_faces(
img=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=grayscale,
enforce_detection=enforce_detection,
align=align,
)
for img, region, confidence in img_objs:
resp_obj = {}
# discard expanded dimension
if len(img.shape) == 4:
img = img[0]
# bgr to rgb
resp_obj["face"] = img[:, :, ::-1]
resp_obj["facial_area"] = region
resp_obj["confidence"] = confidence
resp_objs.append(resp_obj)
return resp_objs
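A minimal sketch of calling the relocated extract_faces through the facade; the returned "face" arrays are RGB and the path is a placeholder:

```python
from deepface import DeepFace

face_objs = DeepFace.extract_faces(
    img_path="dataset/img5.jpg",  # placeholder image
    target_size=(224, 224),
    detector_backend="opencv",
)
for obj in face_objs:
    # "face" is the RGB crop, "facial_area" the bounding box, "confidence" the detector score
    print(obj["facial_area"], obj["confidence"], obj["face"].shape)
```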

deepface/modules/modeling.py (new file)
@ -0,0 +1,71 @@
# built-in dependencies
from typing import Any, Union
# 3rd party dependencies
import tensorflow as tf
# project dependencies
from deepface.basemodels import (
VGGFace,
OpenFace,
Facenet,
Facenet512,
FbDeepFace,
DeepID,
DlibWrapper,
ArcFace,
SFace,
)
from deepface.extendedmodels import Age, Gender, Race, Emotion
# conditional dependencies
tf_version = int(tf.__version__.split(".", maxsplit=1)[0])
if tf_version == 2:
from tensorflow.keras.models import Model
else:
from keras.models import Model
def build_model(model_name: str) -> Union[Model, Any]:
"""
This function builds a deepface model
Parameters:
model_name (string): face recognition or facial attribute model
VGG-Face, Facenet, OpenFace, DeepFace, DeepID for face recognition
Age, Gender, Emotion, Race for facial attributes
Returns:
built deepface model ( (tf.)keras.models.Model )
"""
# singleton design pattern
global model_obj
models = {
"VGG-Face": VGGFace.loadModel,
"OpenFace": OpenFace.loadModel,
"Facenet": Facenet.loadModel,
"Facenet512": Facenet512.loadModel,
"DeepFace": FbDeepFace.loadModel,
"DeepID": DeepID.loadModel,
"Dlib": DlibWrapper.loadModel,
"ArcFace": ArcFace.loadModel,
"SFace": SFace.load_model,
"Emotion": Emotion.loadModel,
"Age": Age.loadModel,
"Gender": Gender.loadModel,
"Race": Race.loadModel,
}
if not "model_obj" in globals():
model_obj = {}
if not model_name in model_obj:
model = models.get(model_name)
if model:
model = model()
model_obj[model_name] = model
else:
raise ValueError(f"Invalid model_name passed - {model_name}")
return model_obj[model_name]
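Because build_model caches loaded models in a module-level dictionary (the singleton pattern above), repeated calls return the same instance instead of re-loading weights. A small sketch of that behavior:

```python
from deepface.modules import modeling

model_a = modeling.build_model("VGG-Face")  # loads weights on the first call
model_b = modeling.build_model("VGG-Face")  # served from the cache
assert model_a is model_b
```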

deepface/modules/recognition.py (new file)
@ -0,0 +1,268 @@
# built-in dependencies
import os
import pickle
from typing import List, Union
import time
# 3rd party dependencies
import numpy as np
import pandas as pd
from tqdm import tqdm
# project dependencies
from deepface.commons import functions, distance as dst
from deepface.commons.logger import Logger
from deepface.modules import representation
logger = Logger(module="deepface/modules/recognition.py")
def find(
img_path: Union[str, np.ndarray],
db_path: str,
model_name: str = "VGG-Face",
distance_metric: str = "cosine",
enforce_detection: bool = True,
detector_backend: str = "opencv",
align: bool = True,
normalization: str = "base",
silent: bool = False,
) -> List[pd.DataFrame]:
"""
This function applies verification several times and finds the identities in a database
Parameters:
img_path: exact image path, numpy array (BGR) or base64 encoded image.
The source image can have many faces; the result will then contain one entry
per face in the source image.
db_path (string): You should store some image files in a folder and pass the
exact folder path to this. A database image can also have many faces.
Then, all detected faces in db side will be considered in the decision.
model_name (string): VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID,
Dlib, ArcFace, SFace or Ensemble
distance_metric (string): cosine, euclidean, euclidean_l2
enforce_detection (bool): The function throws an exception if a face cannot be detected.
Set this to False if you do not want an exception. This might be convenient for low
resolution images.
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
dlib, mediapipe or yolov8.
align (boolean): alignment according to the eye positions.
normalization (string): normalize the input image before feeding to model
silent (boolean): disable some logging and progress bars
Returns:
This function returns a list of pandas DataFrames. Each item in the list corresponds to
an identity in img_path.
"""
tic = time.time()
# -------------------------------
if not os.path.isdir(db_path):
raise ValueError("Passed db_path does not exist!")
target_size = functions.find_target_size(model_name=model_name)
# ---------------------------------------
file_name = f"representations_{model_name}.pkl"
file_name = file_name.replace("-", "_").lower()
df_cols = [
"identity",
f"{model_name}_representation",
"target_x",
"target_y",
"target_w",
"target_h",
]
if os.path.exists(db_path + "/" + file_name):
if not silent:
logger.warn(
f"Representations for images in {db_path} folder were previously stored"
f" in {file_name}. If you added new instances after the creation, then please "
"delete this file and call find function again. It will create it again."
)
with open(f"{db_path}/{file_name}", "rb") as f:
representations = pickle.load(f)
if len(representations) > 0 and len(representations[0]) != len(df_cols):
raise ValueError(
f"Seems existing {db_path}/{file_name} is out-of-the-date."
"Delete it and re-run."
)
if not silent:
logger.info(f"There are {len(representations)} representations found in {file_name}")
else: # create representation.pkl from scratch
employees = []
for r, _, f in os.walk(db_path):
for file in f:
if (
(".jpg" in file.lower())
or (".jpeg" in file.lower())
or (".png" in file.lower())
):
exact_path = r + "/" + file
employees.append(exact_path)
if len(employees) == 0:
raise ValueError(
f"There is no image in the {db_path} folder! "
"Validate that .jpg, .jpeg or .png files exist in this path."
)
# ------------------------
# find representations for db images
representations = []
# for employee in employees:
pbar = tqdm(
range(0, len(employees)),
desc="Finding representations",
disable=silent,
)
for index in pbar:
employee = employees[index]
img_objs = functions.extract_faces(
img=employee,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
for img_content, img_region, _ in img_objs:
embedding_obj = representation.represent(
img_path=img_content,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
img_representation = embedding_obj[0]["embedding"]
instance = []
instance.append(employee)
instance.append(img_representation)
instance.append(img_region["x"])
instance.append(img_region["y"])
instance.append(img_region["w"])
instance.append(img_region["h"])
representations.append(instance)
# -------------------------------
with open(f"{db_path}/{file_name}", "wb") as f:
pickle.dump(representations, f)
if not silent:
logger.info(
f"Representations stored in {db_path}/{file_name} file."
+ "Please delete this file when you add new identities in your database."
)
# ----------------------------
# now, we got representations for facial database
df = pd.DataFrame(
representations,
columns=df_cols,
)
# img path might have more than one face
source_objs = functions.extract_faces(
img=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
resp_obj = []
for source_img, source_region, _ in source_objs:
target_embedding_obj = representation.represent(
img_path=source_img,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
target_representation = target_embedding_obj[0]["embedding"]
result_df = df.copy() # df will be filtered in each img
result_df["source_x"] = source_region["x"]
result_df["source_y"] = source_region["y"]
result_df["source_w"] = source_region["w"]
result_df["source_h"] = source_region["h"]
distances = []
for index, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"]
target_dims = len(list(target_representation))
source_dims = len(list(source_representation))
if target_dims != source_dims:
raise ValueError(
"Source and target embeddings must have same dimensions but "
+ f"{target_dims}:{source_dims}. Model structure may change"
+ " after pickle created. Delete the {file_name} and re-run."
)
if distance_metric == "cosine":
distance = dst.findCosineDistance(source_representation, target_representation)
elif distance_metric == "euclidean":
distance = dst.findEuclideanDistance(source_representation, target_representation)
elif distance_metric == "euclidean_l2":
distance = dst.findEuclideanDistance(
dst.l2_normalize(source_representation),
dst.l2_normalize(target_representation),
)
else:
raise ValueError(f"invalid distance metric passes - {distance_metric}")
distances.append(distance)
# ---------------------------
result_df[f"{model_name}_{distance_metric}"] = distances
threshold = dst.findThreshold(model_name, distance_metric)
result_df = result_df.drop(columns=[f"{model_name}_representation"])
# pylint: disable=unsubscriptable-object
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
result_df = result_df.sort_values(
by=[f"{model_name}_{distance_metric}"], ascending=True
).reset_index(drop=True)
resp_obj.append(result_df)
# -----------------------------------
toc = time.time()
if not silent:
logger.info(f"find function lasts {toc - tic} seconds")
return resp_obj
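A usage sketch for the relocated find function; the probe image and database folder are placeholders, and the representations_vgg_face.pkl cache is written into db_path on the first run:

```python
from deepface import DeepFace

dfs = DeepFace.find(
    img_path="dataset/img1.jpg",  # placeholder probe image
    db_path="dataset",            # placeholder folder of identity images
    model_name="VGG-Face",
    distance_metric="cosine",
)
# one DataFrame per face found in img_path, filtered by threshold and sorted by distance
for df in dfs:
    print(df[["identity", "VGG-Face_cosine"]].head())
```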

deepface/modules/representation.py (new file)
@ -0,0 +1,129 @@
# built-in dependencies
from typing import Any, Dict, List, Union
# 3rd party dependencies
import numpy as np
import cv2
import tensorflow as tf
# project dependencies
from deepface.modules import modeling
from deepface.commons import functions
# conditional dependencies
tf_version = int(tf.__version__.split(".", maxsplit=1)[0])
if tf_version == 2:
from tensorflow.keras.models import Model
else:
from keras.models import Model
def represent(
img_path: Union[str, np.ndarray],
model_name: str = "VGG-Face",
enforce_detection: bool = True,
detector_backend: str = "opencv",
align: bool = True,
normalization: str = "base",
) -> List[Dict[str, Any]]:
"""
This function represents facial images as vectors. The function uses convolutional neural
network models to generate vector embeddings.
Parameters:
img_path (string): exact image path. Alternatively, numpy arrays (BGR) or base64
encoded images could be passed. The source image can have many faces; the result will
then contain one entry per face appearing in the source image.
model_name (string): VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
ArcFace, SFace
enforce_detection (boolean): If no face can be detected in an image, this function
raises an exception by default. Set this to False to suppress the exception.
This might be convenient for low resolution images.
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
dlib, mediapipe or yolov8. A special value `skip` could be used to skip face-detection
and only encode the given image.
align (boolean): alignment according to the eye positions.
normalization (string): normalize the input image before feeding to model
Returns:
Represent function returns a list of object, each object has fields as follows:
{
// Multidimensional vector
// The number of dimensions is changing based on the reference model.
// E.g. FaceNet returns 128 dimensional vector;
// VGG-Face returns 2622 dimensional vector.
"embedding": np.array,
// Detected Facial-Area by Face detection in dict format.
// (x, y) is left-corner point, and (w, h) is the width and height
// If `detector_backend` == `skip`, it is the full image area and nonsense.
"facial_area": dict{"x": int, "y": int, "w": int, "h": int},
// Face detection confidence.
// If `detector_backend` == `skip`, will be 0 and nonsense.
"face_confidence": float
}
"""
resp_objs = []
model = modeling.build_model(model_name)
# ---------------------------------
# pre-processing was already run in verification, so it can be skipped when coming from verify
target_size = functions.find_target_size(model_name=model_name)
if detector_backend != "skip":
img_objs = functions.extract_faces(
img=img_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
else: # skip
# Try to load the image; a loading error raises an exception internally
img, _ = functions.load_image(img_path)
# --------------------------------
if len(img.shape) == 4:
img = img[0] # e.g. (1, 224, 224, 3) to (224, 224, 3)
if len(img.shape) == 3:
img = cv2.resize(img, target_size)
img = np.expand_dims(img, axis=0)
# when called from verify, the image is already normalized; user-given inputs still need it
if img.max() > 1:
img = img.astype(np.float32) / 255.0
# --------------------------------
# make dummy region and confidence to keep compatibility with `extract_faces`
img_region = {"x": 0, "y": 0, "w": img.shape[1], "h": img.shape[2]}
img_objs = [(img, img_region, 0)]
# ---------------------------------
for img, region, confidence in img_objs:
# custom normalization
img = functions.normalize_input(img=img, normalization=normalization)
# represent
# if "keras" in str(type(model)):
if isinstance(model, Model):
# model.predict causes memory issue when it is called in a for loop
# embedding = model.predict(img, verbose=0)[0].tolist()
embedding = model(img, training=False).numpy()[0].tolist()
# if you still get verbose logging, try calling
# - `tf.keras.utils.disable_interactive_logging()`
# in your main program
else:
# SFace and Dlib are not keras models and have no verbose argument
embedding = model.predict(img)[0].tolist()
resp_obj = {}
resp_obj["embedding"] = embedding
resp_obj["facial_area"] = region
resp_obj["face_confidence"] = confidence
resp_objs.append(resp_obj)
return resp_objs
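A sketch of represent with detector_backend="skip", matching the pre-loaded-image branch above; the path is a placeholder and the embedding size depends on the model (e.g. 2622 for VGG-Face):

```python
import cv2

from deepface import DeepFace

img = cv2.imread("dataset/img5.jpg")  # placeholder, already face-cropped image
embedding_objs = DeepFace.represent(img_path=img, detector_backend="skip")
# with "skip", facial_area covers the whole input and face_confidence is 0
print(len(embedding_objs[0]["embedding"]))
```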

deepface/modules/verification.py (new file)
@ -0,0 +1,151 @@
# built-in dependencies
import time
from typing import Any, Dict, Union
# 3rd party dependencies
import numpy as np
# project dependencies
from deepface.commons import functions, distance as dst
from deepface.modules import representation
def verify(
img1_path: Union[str, np.ndarray],
img2_path: Union[str, np.ndarray],
model_name: str = "VGG-Face",
detector_backend: str = "opencv",
distance_metric: str = "cosine",
enforce_detection: bool = True,
align: bool = True,
normalization: str = "base",
) -> Dict[str, Any]:
"""
This function verifies whether an image pair shows the same person or different persons.
In the background, the verification function represents facial images as vectors and then
calculates the similarity between those vectors. Vectors of the same person's images should
have higher similarity (lower distance) than vectors of different persons.
Parameters:
img1_path, img2_path: exact image path as string. Numpy arrays (BGR) or base64 encoded
images are also welcome. If one of the pair has more than one face, the face pair with
maximum similarity is compared.
model_name (str): VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib
, ArcFace and SFace
distance_metric (string): cosine, euclidean, euclidean_l2
enforce_detection (boolean): If no face can be detected in an image, this function
raises an exception by default. Set this to False to suppress the exception.
This might be convenient for low resolution images.
detector_backend (string): set face detector backend to opencv, retinaface, mtcnn, ssd,
dlib, mediapipe or yolov8.
align (boolean): alignment according to the eye positions.
normalization (string): normalize the input image before feeding to model
Returns:
Verify function returns a dictionary.
{
"verified": True
, "distance": 0.2563
, "max_threshold_to_verify": 0.40
, "model": "VGG-Face"
, "similarity_metric": "cosine"
, 'facial_areas': {
'img1': {'x': 345, 'y': 211, 'w': 769, 'h': 769},
'img2': {'x': 318, 'y': 534, 'w': 779, 'h': 779}
}
, "time": 2
}
"""
tic = time.time()
# --------------------------------
target_size = functions.find_target_size(model_name=model_name)
# img pairs might have many faces
img1_objs = functions.extract_faces(
img=img1_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
img2_objs = functions.extract_faces(
img=img2_path,
target_size=target_size,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
)
# --------------------------------
distances = []
regions = []
# now we will find the face pair with minimum distance
for img1_content, img1_region, _ in img1_objs:
for img2_content, img2_region, _ in img2_objs:
img1_embedding_obj = representation.represent(
img_path=img1_content,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
img2_embedding_obj = representation.represent(
img_path=img2_content,
model_name=model_name,
enforce_detection=enforce_detection,
detector_backend="skip",
align=align,
normalization=normalization,
)
img1_representation = img1_embedding_obj[0]["embedding"]
img2_representation = img2_embedding_obj[0]["embedding"]
if distance_metric == "cosine":
distance = dst.findCosineDistance(img1_representation, img2_representation)
elif distance_metric == "euclidean":
distance = dst.findEuclideanDistance(img1_representation, img2_representation)
elif distance_metric == "euclidean_l2":
distance = dst.findEuclideanDistance(
dst.l2_normalize(img1_representation), dst.l2_normalize(img2_representation)
)
else:
raise ValueError("Invalid distance_metric passed - ", distance_metric)
distances.append(distance)
regions.append((img1_region, img2_region))
# -------------------------------
threshold = dst.findThreshold(model_name, distance_metric)
distance = min(distances) # best distance
facial_areas = regions[np.argmin(distances)]
toc = time.time()
# pylint: disable=simplifiable-if-expression
resp_obj = {
"verified": True if distance <= threshold else False,
"distance": distance,
"threshold": threshold,
"model": model_name,
"detector_backend": detector_backend,
"similarity_metric": distance_metric,
"facial_areas": {"img1": facial_areas[0], "img2": facial_areas[1]},
"time": round(toc - tic, 2),
}
return resp_obj
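The verification decision reduces to comparing the best pair distance against a per-model threshold. A toy sketch of the cosine case, assuming findCosineDistance follows the usual 1 - cosine-similarity definition and that findThreshold("VGG-Face", "cosine") yields 0.40 as in the docstring example:

```python
import numpy as np

def cosine_distance(a, b):
    # 1 - cosine similarity, as assumed for deepface.commons.distance.findCosineDistance
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

emb1 = [0.10, 0.30, 0.50]  # toy embeddings, not real model outputs
emb2 = [0.11, 0.29, 0.52]
distance = cosine_distance(emb1, emb2)
threshold = 0.40  # e.g. dst.findThreshold("VGG-Face", "cosine")
print(distance, distance <= threshold)  # verified when under the threshold
```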

tests/test_represent.py
@ -1,3 +1,4 @@
import cv2
from deepface import DeepFace
from deepface.commons.logger import Logger
@ -14,7 +15,7 @@ def test_standard_represent():
logger.info("✅ test standard represent function done")
def test_represent_for_skipped_detector_backend():
def test_represent_for_skipped_detector_backend_with_image_path():
face_img = "dataset/img5.jpg"
img_objs = DeepFace.represent(img_path=face_img, detector_backend="skip")
assert len(img_objs) >= 1
@ -27,4 +28,21 @@ def test_represent_for_skipped_detector_backend():
assert "w" in img_obj["facial_area"].keys()
assert "h" in img_obj["facial_area"].keys()
assert "face_confidence" in img_obj.keys()
logger.info("✅ test represent function for skipped detector backend done")
logger.info("✅ test represent function for skipped detector and image path input backend done")
def test_represent_for_skipped_detector_backend_with_preloaded_image():
face_img = "dataset/img5.jpg"
img = cv2.imread(face_img)
img_objs = DeepFace.represent(img_path=img, detector_backend="skip")
assert len(img_objs) >= 1
img_obj = img_objs[0]
assert "embedding" in img_obj.keys()
assert "facial_area" in img_obj.keys()
assert isinstance(img_obj["facial_area"], dict)
assert "x" in img_obj["facial_area"].keys()
assert "y" in img_obj["facial_area"].keys()
assert "w" in img_obj["facial_area"].keys()
assert "h" in img_obj["facial_area"].keys()
assert "face_confidence" in img_obj.keys()
logger.info("✅ test represent function for skipped detector and preloaded image done")