mirror of
https://github.com/serengil/deepface.git
synced 2025-06-05 19:15:23 +00:00
Merge pull request #1439 from Raghucharan16/master
This commit is contained in:
commit
2d31f57d45
@ -54,7 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
|
||||
Args:
|
||||
model_name (str): model identifier
|
||||
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||
ArcFace, SFace and GhostFaceNet for face recognition
|
||||
ArcFace, SFace GhostFaceNet and Buffalo_L for face recognition
|
||||
- Age, Gender, Emotion, Race for facial attributes
|
||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
|
||||
yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
|
||||
|
96
deepface/models/facial_recognition/Buffalo_L.py
Normal file
96
deepface/models/facial_recognition/Buffalo_L.py
Normal file
@ -0,0 +1,96 @@
|
||||
import os
|
||||
from typing import List, Union
|
||||
import numpy as np
|
||||
|
||||
from deepface.commons import weight_utils, folder_utils
|
||||
from deepface.commons.logger import Logger
|
||||
from deepface.models.FacialRecognition import FacialRecognition
|
||||
|
||||
logger = Logger()
|
||||
|
||||
class Buffalo_L(FacialRecognition):
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.input_shape = (112, 112)
|
||||
self.output_shape = 512
|
||||
self.load_model()
|
||||
|
||||
def load_model(self):
|
||||
"""Load the InsightFace Buffalo_L recognition model."""
|
||||
try:
|
||||
from insightface.model_zoo import get_model
|
||||
except Exception as err:
|
||||
raise ModuleNotFoundError(
|
||||
"InsightFace and its dependencies are optional for the Buffalo_L model. "
|
||||
"Please install them with: "
|
||||
"pip install insightface>=0.7.3 onnxruntime>=1.9.0 typing-extensions pydantic"
|
||||
"albumentations"
|
||||
) from err
|
||||
|
||||
sub_dir = "buffalo_l"
|
||||
model_file = "webface_r50.onnx"
|
||||
model_rel_path = os.path.join(sub_dir, model_file)
|
||||
home = folder_utils.get_deepface_home()
|
||||
weights_dir = os.path.join(home, ".deepface", "weights")
|
||||
buffalo_l_dir = os.path.join(weights_dir, sub_dir)
|
||||
|
||||
if not os.path.exists(buffalo_l_dir):
|
||||
os.makedirs(buffalo_l_dir, exist_ok=True)
|
||||
logger.info(f"Created directory: {buffalo_l_dir}")
|
||||
|
||||
weights_path = weight_utils.download_weights_if_necessary(
|
||||
file_name=model_rel_path,
|
||||
source_url="https://drive.google.com/uc?export=download&confirm=pbef&id=1N0GL-8ehw_bz2eZQWz2b0A5XBdXdxZhg" #pylint: disable=line-too-long
|
||||
)
|
||||
|
||||
if not os.path.exists(weights_path):
|
||||
raise FileNotFoundError(f"Model file not found at: {weights_path}")
|
||||
logger.debug(f"Model file found at: {weights_path}")
|
||||
|
||||
self.model = get_model(weights_path)
|
||||
self.model.prepare(ctx_id=-1, input_size=self.input_shape)
|
||||
|
||||
def preprocess(self, img: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Preprocess the input image or batch of images.
|
||||
|
||||
Args:
|
||||
img: Input image or batch with shape (112, 112, 3)
|
||||
or (batch_size, 112, 112, 3).
|
||||
|
||||
Returns:
|
||||
Preprocessed image(s) with RGB converted to BGR.
|
||||
"""
|
||||
if len(img.shape) == 3:
|
||||
img = np.expand_dims(img, axis=0) # Convert single image to batch of 1
|
||||
elif len(img.shape) != 4:
|
||||
raise ValueError(f"Input must be (112, 112, 3) or (X, 112, 112, 3). Got {img.shape}")
|
||||
# Convert RGB to BGR for the entire batch
|
||||
img = img[:, :, :, ::-1]
|
||||
return img
|
||||
|
||||
def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
|
||||
"""
|
||||
Extract facial embedding(s) from the input image or batch of images.
|
||||
|
||||
Args:
|
||||
img: Input image or batch with shape (112, 112, 3)
|
||||
or (batch_size, 112, 112, 3).
|
||||
|
||||
Returns:
|
||||
Embedding as a list of floats (single image)
|
||||
or list of lists of floats (batch).
|
||||
"""
|
||||
# Preprocess the input (single image or batch)
|
||||
img = self.preprocess(img)
|
||||
batch_size = img.shape[0]
|
||||
|
||||
# Handle both single images and batches
|
||||
embeddings = []
|
||||
for i in range(batch_size):
|
||||
embedding = self.model.get_feat(img[i])
|
||||
embeddings.append(embedding.flatten().tolist())
|
||||
|
||||
# Return single embedding if batch_size is 1, otherwise return list of embeddings
|
||||
return embeddings[0] if batch_size == 1 else embeddings
|
||||
|
@ -11,7 +11,8 @@ from deepface.models.facial_recognition import (
|
||||
SFace,
|
||||
Dlib,
|
||||
Facenet,
|
||||
GhostFaceNet
|
||||
GhostFaceNet,
|
||||
Buffalo_L
|
||||
)
|
||||
from deepface.models.face_detection import (
|
||||
FastMtCnn,
|
||||
@ -59,7 +60,8 @@ def build_model(task: str, model_name: str) -> Any:
|
||||
"Dlib": Dlib.DlibClient,
|
||||
"ArcFace": ArcFace.ArcFaceClient,
|
||||
"SFace": SFace.SFaceClient,
|
||||
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient
|
||||
"GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
|
||||
"Buffalo_L": Buffalo_L.Buffalo_L
|
||||
},
|
||||
"spoofing": {
|
||||
"Fasnet": FasNet.Fasnet,
|
||||
|
@ -423,6 +423,7 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
|
||||
"DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
|
||||
"DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
|
||||
"GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
|
||||
"Buffalo_L": {"cosine": 0.65, "euclidean": 11.13, "euclidean_l2": 1.1},
|
||||
}
|
||||
|
||||
threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
|
||||
|
@ -12,4 +12,4 @@ flask_cors>=4.0.1
|
||||
mtcnn>=0.1.0
|
||||
retina-face>=0.0.14
|
||||
fire>=0.4.0
|
||||
gunicorn>=20.1.0
|
||||
gunicorn>=20.1.0
|
@ -3,4 +3,10 @@ mediapipe>=0.8.7.3
|
||||
dlib>=19.20.0
|
||||
ultralytics>=8.0.122
|
||||
facenet-pytorch>=2.5.3
|
||||
torch>=2.1.2
|
||||
torch>=2.1.2
|
||||
insightface>=0.7.3
|
||||
onnxruntime>=1.9.0
|
||||
tf-keras
|
||||
typing-extensions
|
||||
pydantic
|
||||
albumentations
|
Loading…
x
Reference in New Issue
Block a user