mirror of
https://github.com/serengil/deepface.git
synced 2025-06-06 11:35:21 +00:00
added reviewed changes
This commit is contained in:
parent
89b2d49eb5
commit
e19c7fcc1c
@ -54,7 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
|
||||
Args:
|
||||
model_name (str): model identifier
|
||||
- VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib,
|
||||
ArcFace, SFace GhostFaceNet and buffalo_l for face recognition
|
||||
ArcFace, SFace GhostFaceNet and Buffalo_L for face recognition
|
||||
- Age, Gender, Emotion, Race for facial attributes
|
||||
- opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n,
|
||||
yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors
|
||||
|
@ -1,73 +1,73 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from deepface.models.FacialRecognition import FacialRecognition
|
||||
from deepface.commons.logger import Logger
|
||||
from deepface.basemodel import get_weights_path
|
||||
from deepface.common import weight_utils
|
||||
import os
|
||||
import numpy as np
|
||||
from deepface.commons import weight_utils
|
||||
from deepface.commons.logger import Logger
|
||||
from deepface.models.FacialRecognition import FacialRecognition
|
||||
|
||||
logger = Logger()
|
||||
|
||||
# Check for insightface dependency
|
||||
try:
|
||||
from insightface.model_zoo import get_model
|
||||
except ModuleNotFoundError:
|
||||
except ModuleNotFoundError as err:
|
||||
raise ModuleNotFoundError(
|
||||
"InsightFace is an optional dependency for the Buffalo_L model."
|
||||
"You can install it with: pip install insightface>=0.7.3"
|
||||
)
|
||||
) from err
|
||||
|
||||
class Buffalo_L(FacialRecognition):
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.input_shape = (112, 112) # Buffalo_L recognition model expects 112x112
|
||||
self.output_shape = 512 # Embedding size for Buffalo_L
|
||||
self.input_shape = (112, 112) # Buffalo_L expects 112x112
|
||||
self.output_shape = 512 # Embedding size
|
||||
self.load_model()
|
||||
|
||||
def load_model(self):
|
||||
root = os.path.join(get_weights_path(), 'insightface')
|
||||
model_name = 'buffalo_l/w600k_r50.onnx'
|
||||
model_path = os.path.join(root, model_name)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
url = 'https://drive.google.com/file/d/1N0GL-8ehw_bz2eZQWz2b0A5XBdXdxZhg/view?usp=sharing'
|
||||
weight_utils.download_file(url, model_path)
|
||||
|
||||
self.model = get_model(model_name, root=root)
|
||||
"""
|
||||
Load the InsightFace Buffalo_L recognition model.
|
||||
"""
|
||||
# Use DeepFace's utility to download weights if necessary
|
||||
model_rel_path = os.path.join("insightface", "buffalo_l", "w600k_r50.onnx")
|
||||
weights_path = weight_utils.download_weights_if_necessary(
|
||||
file_name="webface_r50.onnx",
|
||||
source_url="https://drive.google.com/uc?export=download&confirm=pbef&id=1N0GL-8ehw_bz2eZQWz2b0A5XBdXdxZhg"
|
||||
)
|
||||
# Load model from weights folder
|
||||
self.model = get_model("buffalo_l/w600k_r50.onnx", root=os.path.dirname(weights_path))
|
||||
self.model.prepare(ctx_id=-1, input_size=self.input_shape)
|
||||
|
||||
def preprocess(self, img):
|
||||
def preprocess(self, img: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Preprocess the image to match InsightFace recognition model expectations.
|
||||
Args:
|
||||
img (numpy array): Image in shape (1, 112, 112, 3) or (112, 112, 3)
|
||||
img: Image in shape (1, 112, 112, 3) or (112, 112, 3)
|
||||
Returns:
|
||||
numpy array: Preprocessed image
|
||||
Preprocessed image as numpy array
|
||||
"""
|
||||
if len(img.shape) == 4: # (1, 112, 112, 3)
|
||||
img = img[0] # Remove batch dimension
|
||||
if img.max() <= 1.0: # If normalized to [0, 1]
|
||||
img = (img * 255).astype(np.uint8)
|
||||
# Always convert RGB to BGR (InsightFace expects BGR, DeepFace provides RGB)
|
||||
# Always convert RGB to BGR (DeepFace outputs RGB, InsightFace expects BGR)
|
||||
img = img[:, :, ::-1]
|
||||
return img
|
||||
|
||||
def forward(self, img):
|
||||
def forward(self, img: np.ndarray) -> list[float]:
|
||||
"""
|
||||
Extract face embedding from a pre-cropped face image.
|
||||
Args:
|
||||
img (numpy array): Preprocessed face image with shape (1, 112, 112, 3)
|
||||
img: Preprocessed face image with shape (1, 112, 112, 3)
|
||||
Returns:
|
||||
numpy array: Face embedding
|
||||
Face embedding as a list of floats
|
||||
"""
|
||||
img = self.preprocess(img)
|
||||
embedding = self.model.get_feat(img)
|
||||
|
||||
# Handle different embedding formats
|
||||
if isinstance(embedding, np.ndarray):
|
||||
if len(embedding.shape) > 1:
|
||||
embedding = embedding.flatten()
|
||||
if isinstance(embedding, np.ndarray) and len(embedding.shape) > 1:
|
||||
embedding = embedding.flatten()
|
||||
elif isinstance(embedding, list):
|
||||
embedding = np.array(embedding).flatten()
|
||||
|
||||
return embedding
|
||||
else:
|
||||
raise ValueError(f"Unexpected embedding type: {type(embedding)}")
|
||||
|
||||
return embedding.tolist() # Convert to list per FacialRecognition spec
|
Loading…
x
Reference in New Issue
Block a user