Merge pull request #1439 from Raghucharan16/master

This commit is contained in:
Sefik Ilkin Serengil 2025-03-01 09:46:38 +00:00 committed by GitHub
commit 2d31f57d45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 110 additions and 5 deletions

View File

@ -54,7 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
Args:
model_name (str): model identifier
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
ArcFace, SFace and GhostFaceNet for face recognition
ArcFace, SFace GhostFaceNet and Buffalo_L for face recognition
- Age, Gender, Emotion, Race for facial attributes
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors

View File

@ -0,0 +1,96 @@
import os
from typing import List, Union
import numpy as np
from deepface.commons import weight_utils, folder_utils
from deepface.commons.logger import Logger
from deepface.models.FacialRecognition import FacialRecognition
logger = Logger()
class Buffalo_L(FacialRecognition):
def __init__(self):
self.model = None
self.input_shape = (112, 112)
self.output_shape = 512
self.load_model()
def load_model(self):
"""Load the InsightFace Buffalo_L recognition model."""
try:
from insightface.model_zoo import get_model
except Exception as err:
raise ModuleNotFoundError(
"InsightFace and its dependencies are optional for the Buffalo_L model. "
"Please install them with: "
"pip install insightface>=0.7.3 onnxruntime>=1.9.0 typing-extensions pydantic"
"albumentations"
) from err
sub_dir = "buffalo_l"
model_file = "webface_r50.onnx"
model_rel_path = os.path.join(sub_dir, model_file)
home = folder_utils.get_deepface_home()
weights_dir = os.path.join(home, ".deepface", "weights")
buffalo_l_dir = os.path.join(weights_dir, sub_dir)
if not os.path.exists(buffalo_l_dir):
os.makedirs(buffalo_l_dir, exist_ok=True)
logger.info(f"Created directory: {buffalo_l_dir}")
weights_path = weight_utils.download_weights_if_necessary(
file_name=model_rel_path,
source_url="https://drive.google.com/uc?export=download&confirm=pbef&id=1N0GL-8ehw_bz2eZQWz2b0A5XBdXdxZhg" #pylint: disable=line-too-long
)
if not os.path.exists(weights_path):
raise FileNotFoundError(f"Model file not found at: {weights_path}")
logger.debug(f"Model file found at: {weights_path}")
self.model = get_model(weights_path)
self.model.prepare(ctx_id=-1, input_size=self.input_shape)
def preprocess(self, img: np.ndarray) -> np.ndarray:
"""
Preprocess the input image or batch of images.
Args:
img: Input image or batch with shape (112, 112, 3)
or (batch_size, 112, 112, 3).
Returns:
Preprocessed image(s) with RGB converted to BGR.
"""
if len(img.shape) == 3:
img = np.expand_dims(img, axis=0) # Convert single image to batch of 1
elif len(img.shape) != 4:
raise ValueError(f"Input must be (112, 112, 3) or (X, 112, 112, 3). Got {img.shape}")
# Convert RGB to BGR for the entire batch
img = img[:, :, :, ::-1]
return img
def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
"""
Extract facial embedding(s) from the input image or batch of images.
Args:
img: Input image or batch with shape (112, 112, 3)
or (batch_size, 112, 112, 3).
Returns:
Embedding as a list of floats (single image)
or list of lists of floats (batch).
"""
# Preprocess the input (single image or batch)
img = self.preprocess(img)
batch_size = img.shape[0]
# Handle both single images and batches
embeddings = []
for i in range(batch_size):
embedding = self.model.get_feat(img[i])
embeddings.append(embedding.flatten().tolist())
# Return single embedding if batch_size is 1, otherwise return list of embeddings
return embeddings[0] if batch_size == 1 else embeddings

View File

@ -11,7 +11,8 @@ from deepface.models.facial_recognition import (
SFace,
Dlib,
Facenet,
GhostFaceNet
GhostFaceNet,
Buffalo_L
)
from deepface.models.face_detection import (
FastMtCnn,
@ -59,7 +60,8 @@ def build_model(task: str, model_name: str) -> Any:
"Dlib": Dlib.DlibClient,
"ArcFace": ArcFace.ArcFaceClient,
"SFace": SFace.SFaceClient,
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
"Buffalo_L": Buffalo_L.Buffalo_L
},
"spoofing": {
"Fasnet": FasNet.Fasnet,

View File

@ -423,6 +423,7 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
"Buffalo_L": {"cosine": 0.65, "euclidean": 11.13, "euclidean_l2": 1.1},
}
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)

View File

@ -12,4 +12,4 @@ flask_cors>=4.0.1
mtcnn>=0.1.0
retina-face>=0.0.14
fire>=0.4.0
gunicorn>=20.1.0
gunicorn>=20.1.0

View File

@ -3,4 +3,10 @@ mediapipe>=0.8.7.3
dlib>=19.20.0
ultralytics>=8.0.122
facenet-pytorch>=2.5.3
torch>=2.1.2
torch>=2.1.2
insightface>=0.7.3
onnxruntime>=1.9.0
tf-keras
typing-extensions
pydantic
albumentations