diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index ae0fb0b..99bcb64 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -54,7 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: Args: model_name (str): model identifier - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, - ArcFace, SFace and GhostFaceNet for face recognition + ArcFace, SFace GhostFaceNet and buffalo_l for face recognition - Age, Gender, Emotion, Race for facial attributes - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n, yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors diff --git a/deepface/models/facial_recognition/Buffalo_L.py b/deepface/models/facial_recognition/Buffalo_L.py new file mode 100644 index 0000000..0ec2365 --- /dev/null +++ b/deepface/models/facial_recognition/Buffalo_L.py @@ -0,0 +1,76 @@ +import cv2 +import numpy as np +from insightface.model_zoo import get_model +from deepface.models.FacialRecognition import FacialRecognition +from deepface.commons.logger import Logger + +logger = Logger() + +class Buffalo_L(FacialRecognition): + def __init__(self): + self.model = None + self.input_shape = (112, 112) # Buffalo_L recognition model expects 112x112 + self.output_shape = 512 # Embedding size for Buffalo_L + self.load_model() + + def load_model(self): + """ + Load the InsightFace Buffalo_L recognition model. + """ + # Load the recognition model directly (e.g., w600k_r50 from buffalo_l) + # The buffalo_l package includes recognition model weights + self.model = get_model('buffalo_l/w600k_r50.onnx', download=True) + self.model.prepare(ctx_id=-1, input_size=self.input_shape) # ctx_id=-1 for CPU + + def preprocess(self, img): + """ + Preprocess the image to match InsightFace recognition model expectations. + Args: + img (numpy array): Image in shape (1, 112, 112, 3) or (112, 112, 3) + Returns: + numpy array: Preprocessed image + """ + if len(img.shape) == 4: # (1, 112, 112, 3) + img = img[0] # Remove batch dimension + # Ensure image is in uint8 format (0-255 range) + if img.max() <= 1.0: # If normalized to [0, 1] + img = (img * 255).astype(np.uint8) + # Convert to BGR if in RGB (InsightFace expects BGR) + if img.shape[2] == 3: + img = img[:, :, ::-1] # RGB to BGR + return img + + def forward(self, img): + """ + Extract face embedding from a pre-cropped face image. + Args: + img (numpy array): Preprocessed face image with shape (1, 112, 112, 3) + Returns: + numpy array: Face embedding + """ + # Preprocess the input image + img = self.preprocess(img) + + # Extract embedding directly (no detection needed) + embedding = self.model.get_feat(img) + + # InsightFace recognition models return a list or array; ensure 1D output + if isinstance(embedding, (list, np.ndarray)) and len(embedding.shape) > 1: + embedding = embedding.flatten() + + return embedding + + def verify(self, img1, img2, threshold=0.65): + """ + Verify if two images contain the same person using cosine similarity. + Args: + img1, img2 (numpy arrays): Preprocessed images + threshold (float): Cosine similarity threshold + Returns: + tuple: (similarity_score, is_same_person) + """ + emb1 = self.forward(img1) + emb2 = self.forward(img2) + + similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)) + return similarity, similarity > threshold \ No newline at end of file diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index 176d9e7..7ec067d 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -11,7 +11,8 @@ from deepface.models.facial_recognition import ( SFace, Dlib, Facenet, - GhostFaceNet + GhostFaceNet, + Buffalo_L ) from deepface.models.face_detection import ( FastMtCnn, @@ -59,7 +60,8 @@ def build_model(task: str, model_name: str) -> Any: "Dlib": Dlib.DlibClient, "ArcFace": ArcFace.ArcFaceClient, "SFace": SFace.SFaceClient, - "GhostFaceNet": GhostFaceNet.GhostFaceNetClient + "GhostFaceNet": GhostFaceNet.GhostFaceNetClient, + "Buffalo_L": Buffalo_L.Buffalo_L }, "spoofing": { "Fasnet": FasNet.Fasnet, diff --git a/requirements.txt b/requirements.txt index c208ada..fc28fd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,6 @@ mtcnn>=0.1.0 retina-face>=0.0.14 fire>=0.4.0 gunicorn>=20.1.0 +tf-keras +insightface==0.7.3 +onnxruntime>=1.9.0 \ No newline at end of file