mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 19:45:21 +00:00
Merge pull request #1439 from Raghucharan16/master
This commit is contained in:
commit
2d31f57d45
@ -54,7 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
|
|||||||
Args:
|
Args:
|
||||||
model_name (str): model identifier
|
model_name (str): model identifier
|
||||||
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||||
ArcFace, SFace and GhostFaceNet for face recognition
|
ArcFace, SFace GhostFaceNet and Buffalo_L for face recognition
|
||||||
- Age, Gender, Emotion, Race for facial attributes
|
- Age, Gender, Emotion, Race for facial attributes
|
||||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
|
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
|
||||||
yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
|
yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
|
||||||
|
96
deepface/models/facial_recognition/Buffalo_L.py
Normal file
96
deepface/models/facial_recognition/Buffalo_L.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Union
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from deepface.commons import weight_utils, folder_utils
|
||||||
|
from deepface.commons.logger import Logger
|
||||||
|
from deepface.models.FacialRecognition import FacialRecognition
|
||||||
|
|
||||||
|
logger = Logger()
|
||||||
|
|
||||||
|
class Buffalo_L(FacialRecognition):
|
||||||
|
def __init__(self):
|
||||||
|
self.model = None
|
||||||
|
self.input_shape = (112, 112)
|
||||||
|
self.output_shape = 512
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
"""Load the InsightFace Buffalo_L recognition model."""
|
||||||
|
try:
|
||||||
|
from insightface.model_zoo import get_model
|
||||||
|
except Exception as err:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"InsightFace and its dependencies are optional for the Buffalo_L model. "
|
||||||
|
"Please install them with: "
|
||||||
|
"pip install insightface>=0.7.3 onnxruntime>=1.9.0 typing-extensions pydantic"
|
||||||
|
"albumentations"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
sub_dir = "buffalo_l"
|
||||||
|
model_file = "webface_r50.onnx"
|
||||||
|
model_rel_path = os.path.join(sub_dir, model_file)
|
||||||
|
home = folder_utils.get_deepface_home()
|
||||||
|
weights_dir = os.path.join(home, ".deepface", "weights")
|
||||||
|
buffalo_l_dir = os.path.join(weights_dir, sub_dir)
|
||||||
|
|
||||||
|
if not os.path.exists(buffalo_l_dir):
|
||||||
|
os.makedirs(buffalo_l_dir, exist_ok=True)
|
||||||
|
logger.info(f"Created directory: {buffalo_l_dir}")
|
||||||
|
|
||||||
|
weights_path = weight_utils.download_weights_if_necessary(
|
||||||
|
file_name=model_rel_path,
|
||||||
|
source_url="https://drive.google.com/uc?export=download&confirm=pbef&id=1N0GL-8ehw_bz2eZQWz2b0A5XBdXdxZhg" #pylint: disable=line-too-long
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.exists(weights_path):
|
||||||
|
raise FileNotFoundError(f"Model file not found at: {weights_path}")
|
||||||
|
logger.debug(f"Model file found at: {weights_path}")
|
||||||
|
|
||||||
|
self.model = get_model(weights_path)
|
||||||
|
self.model.prepare(ctx_id=-1, input_size=self.input_shape)
|
||||||
|
|
||||||
|
def preprocess(self, img: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Preprocess the input image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: Input image or batch with shape (112, 112, 3)
|
||||||
|
or (batch_size, 112, 112, 3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preprocessed image(s) with RGB converted to BGR.
|
||||||
|
"""
|
||||||
|
if len(img.shape) == 3:
|
||||||
|
img = np.expand_dims(img, axis=0) # Convert single image to batch of 1
|
||||||
|
elif len(img.shape) != 4:
|
||||||
|
raise ValueError(f"Input must be (112, 112, 3) or (X, 112, 112, 3). Got {img.shape}")
|
||||||
|
# Convert RGB to BGR for the entire batch
|
||||||
|
img = img[:, :, :, ::-1]
|
||||||
|
return img
|
||||||
|
|
||||||
|
def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
|
||||||
|
"""
|
||||||
|
Extract facial embedding(s) from the input image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: Input image or batch with shape (112, 112, 3)
|
||||||
|
or (batch_size, 112, 112, 3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding as a list of floats (single image)
|
||||||
|
or list of lists of floats (batch).
|
||||||
|
"""
|
||||||
|
# Preprocess the input (single image or batch)
|
||||||
|
img = self.preprocess(img)
|
||||||
|
batch_size = img.shape[0]
|
||||||
|
|
||||||
|
# Handle both single images and batches
|
||||||
|
embeddings = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
embedding = self.model.get_feat(img[i])
|
||||||
|
embeddings.append(embedding.flatten().tolist())
|
||||||
|
|
||||||
|
# Return single embedding if batch_size is 1, otherwise return list of embeddings
|
||||||
|
return embeddings[0] if batch_size == 1 else embeddings
|
||||||
|
|
@ -11,7 +11,8 @@ from deepface.models.facial_recognition import (
|
|||||||
SFace,
|
SFace,
|
||||||
Dlib,
|
Dlib,
|
||||||
Facenet,
|
Facenet,
|
||||||
GhostFaceNet
|
GhostFaceNet,
|
||||||
|
Buffalo_L
|
||||||
)
|
)
|
||||||
from deepface.models.face_detection import (
|
from deepface.models.face_detection import (
|
||||||
FastMtCnn,
|
FastMtCnn,
|
||||||
@ -59,7 +60,8 @@ def build_model(task: str, model_name: str) -> Any:
|
|||||||
"Dlib": Dlib.DlibClient,
|
"Dlib": Dlib.DlibClient,
|
||||||
"ArcFace": ArcFace.ArcFaceClient,
|
"ArcFace": ArcFace.ArcFaceClient,
|
||||||
"SFace": SFace.SFaceClient,
|
"SFace": SFace.SFaceClient,
|
||||||
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient
|
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
|
||||||
|
"Buffalo_L": Buffalo_L.Buffalo_L
|
||||||
},
|
},
|
||||||
"spoofing": {
|
"spoofing": {
|
||||||
"Fasnet": FasNet.Fasnet,
|
"Fasnet": FasNet.Fasnet,
|
||||||
|
@ -423,6 +423,7 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
|
|||||||
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
|
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
|
||||||
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
|
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
|
||||||
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
|
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
|
||||||
|
"Buffalo_L": {"cosine": 0.65, "euclidean": 11.13, "euclidean_l2": 1.1},
|
||||||
}
|
}
|
||||||
|
|
||||||
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
|
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
|
||||||
|
@ -4,3 +4,9 @@ dlib>=19.20.0
|
|||||||
ultralytics>=8.0.122
|
ultralytics>=8.0.122
|
||||||
facenet-pytorch>=2.5.3
|
facenet-pytorch>=2.5.3
|
||||||
torch>=2.1.2
|
torch>=2.1.2
|
||||||
|
insightface>=0.7.3
|
||||||
|
onnxruntime>=1.9.0
|
||||||
|
tf-keras
|
||||||
|
typing-extensions
|
||||||
|
pydantic
|
||||||
|
albumentations
|
Loading…
x
Reference in New Issue
Block a user