diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 1fc95b6..2f20a13 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -54,7 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: Args: model_name (str): model identifier - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, - ArcFace, SFace and GhostFaceNet for face recognition + ArcFace, SFace GhostFaceNet and Buffalo_L for face recognition - Age, Gender, Emotion, Race for facial attributes - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n, yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors diff --git a/deepface/models/facial_recognition/Buffalo_L.py b/deepface/models/facial_recognition/Buffalo_L.py new file mode 100644 index 0000000..eeca234 --- /dev/null +++ b/deepface/models/facial_recognition/Buffalo_L.py @@ -0,0 +1,96 @@ +import os +from typing import List, Union +import numpy as np + +from deepface.commons import weight_utils, folder_utils +from deepface.commons.logger import Logger +from deepface.models.FacialRecognition import FacialRecognition + +logger = Logger() + +class Buffalo_L(FacialRecognition): + def __init__(self): + self.model = None + self.input_shape = (112, 112) + self.output_shape = 512 + self.load_model() + + def load_model(self): + """Load the InsightFace Buffalo_L recognition model.""" + try: + from insightface.model_zoo import get_model + except Exception as err: + raise ModuleNotFoundError( + "InsightFace and its dependencies are optional for the Buffalo_L model. " + "Please install them with: " + "pip install insightface>=0.7.3 onnxruntime>=1.9.0 typing-extensions pydantic" + "albumentations" + ) from err + + sub_dir = "buffalo_l" + model_file = "webface_r50.onnx" + model_rel_path = os.path.join(sub_dir, model_file) + home = folder_utils.get_deepface_home() + weights_dir = os.path.join(home, ".deepface", "weights") + buffalo_l_dir = os.path.join(weights_dir, sub_dir) + + if not os.path.exists(buffalo_l_dir): + os.makedirs(buffalo_l_dir, exist_ok=True) + logger.info(f"Created directory: {buffalo_l_dir}") + + weights_path = weight_utils.download_weights_if_necessary( + file_name=model_rel_path, + source_url="https://drive.google.com/uc?export=download&confirm=pbef&id=1N0GL-8ehw_bz2eZQWz2b0A5XBdXdxZhg" #pylint: disable=line-too-long + ) + + if not os.path.exists(weights_path): + raise FileNotFoundError(f"Model file not found at: {weights_path}") + logger.debug(f"Model file found at: {weights_path}") + + self.model = get_model(weights_path) + self.model.prepare(ctx_id=-1, input_size=self.input_shape) + + def preprocess(self, img: np.ndarray) -> np.ndarray: + """ + Preprocess the input image or batch of images. + + Args: + img: Input image or batch with shape (112, 112, 3) + or (batch_size, 112, 112, 3). + + Returns: + Preprocessed image(s) with RGB converted to BGR. + """ + if len(img.shape) == 3: + img = np.expand_dims(img, axis=0) # Convert single image to batch of 1 + elif len(img.shape) != 4: + raise ValueError(f"Input must be (112, 112, 3) or (X, 112, 112, 3). Got {img.shape}") + # Convert RGB to BGR for the entire batch + img = img[:, :, :, ::-1] + return img + + def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]: + """ + Extract facial embedding(s) from the input image or batch of images. + + Args: + img: Input image or batch with shape (112, 112, 3) + or (batch_size, 112, 112, 3). + + Returns: + Embedding as a list of floats (single image) + or list of lists of floats (batch). + """ + # Preprocess the input (single image or batch) + img = self.preprocess(img) + batch_size = img.shape[0] + + # Handle both single images and batches + embeddings = [] + for i in range(batch_size): + embedding = self.model.get_feat(img[i]) + embeddings.append(embedding.flatten().tolist()) + + # Return single embedding if batch_size is 1, otherwise return list of embeddings + return embeddings[0] if batch_size == 1 else embeddings + \ No newline at end of file diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index 176d9e7..7ec067d 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -11,7 +11,8 @@ from deepface.models.facial_recognition import ( SFace, Dlib, Facenet, - GhostFaceNet + GhostFaceNet, + Buffalo_L ) from deepface.models.face_detection import ( FastMtCnn, @@ -59,7 +60,8 @@ def build_model(task: str, model_name: str) -> Any: "Dlib": Dlib.DlibClient, "ArcFace": ArcFace.ArcFaceClient, "SFace": SFace.SFaceClient, - "GhostFaceNet": GhostFaceNet.GhostFaceNetClient + "GhostFaceNet": GhostFaceNet.GhostFaceNetClient, + "Buffalo_L": Buffalo_L.Buffalo_L }, "spoofing": { "Fasnet": FasNet.Fasnet, diff --git a/deepface/modules/verification.py b/deepface/modules/verification.py index 5ddf10f..35d2f25 100644 --- a/deepface/modules/verification.py +++ b/deepface/modules/verification.py @@ -423,6 +423,7 @@ def find_threshold(model_name: str, distance_metric: str) -> float: "DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64}, "DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17}, "GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10}, + "Buffalo_L": {"cosine": 0.65, "euclidean": 11.13, "euclidean_l2": 1.1}, } threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4) diff --git a/requirements.txt b/requirements.txt index c208ada..afb1518 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ flask_cors>=4.0.1 mtcnn>=0.1.0 retina-face>=0.0.14 fire>=0.4.0 -gunicorn>=20.1.0 +gunicorn>=20.1.0 \ No newline at end of file diff --git a/requirements_additional.txt b/requirements_additional.txt index ea76fde..cc38a19 100644 --- a/requirements_additional.txt +++ b/requirements_additional.txt @@ -3,4 +3,10 @@ mediapipe>=0.8.7.3 dlib>=19.20.0 ultralytics>=8.0.122 facenet-pytorch>=2.5.3 -torch>=2.1.2 \ No newline at end of file +torch>=2.1.2 +insightface>=0.7.3 +onnxruntime>=1.9.0 +tf-keras +typing-extensions +pydantic +albumentations \ No newline at end of file