Mirror of https://github.com/serengil/deepface.git (synced 2025-06-04 02:20:06 +00:00)
Commit 5a73d91744: "code review changes"
Parent: 41ae9bbcf3
@@ -1,11 +1,19 @@
 import cv2
 import numpy as np
-from insightface.model_zoo import get_model
 from deepface.models.FacialRecognition import FacialRecognition
 from deepface.commons.logger import Logger

 logger = Logger()

+# Check for insightface dependency
+try:
+    from insightface.model_zoo import get_model
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "InsightFace is an optional dependency for the Buffalo_L model."
+        "You can install it with: pip install insightface>=0.7.3"
+    )
+
 class Buffalo_L(FacialRecognition):
     def __init__(self):
         self.model = None
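As a quick illustration of the optional-dependency guard added above, a minimal sketch of how a caller could probe for InsightFace before touching the model; the probe itself is an assumption, only the import and install hint come from the diff:

import importlib.util

# Check whether the optional dependency is installed, mirroring the try/except guard in the commit
if importlib.util.find_spec("insightface") is None:
    print("InsightFace is missing; install it with: pip install insightface>=0.7.3")
else:
    from insightface.model_zoo import get_model  # the same import the Buffalo_L module guards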
@@ -17,8 +25,6 @@ class Buffalo_L(FacialRecognition):
         """
         Load the InsightFace Buffalo_L recognition model.
         """
-        # Load the recognition model directly (e.g., w600k_r50 from buffalo_l)
-        # The buffalo_l package includes recognition model weights
         self.model = get_model('buffalo_l/w600k_r50.onnx', download=True)
         self.model.prepare(ctx_id=-1, input_size=self.input_shape)  # ctx_id=-1 for CPU

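For reference, the same InsightFace calls can be exercised outside the class. A minimal sketch, assuming the buffalo_l weights can be downloaded and that the model expects a 112x112 BGR uint8 crop (as the preprocess comments in the next hunk suggest); the dummy input is an assumption:

import numpy as np
from insightface.model_zoo import get_model

# Same call as in load_model: fetch the w600k_r50 recognition model from the buffalo_l package
model = get_model('buffalo_l/w600k_r50.onnx', download=True)
model.prepare(ctx_id=-1, input_size=(112, 112))  # ctx_id=-1 selects CPU

# Dummy aligned face crop: 112x112, BGR, uint8
face = np.zeros((112, 112, 3), dtype=np.uint8)
embedding = model.get_feat(face)
print(np.asarray(embedding).shape)  # w600k_r50 typically yields a 512-dimensional embedding, often shaped (1, 512)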
@@ -32,12 +38,10 @@ class Buffalo_L(FacialRecognition):
         """
         if len(img.shape) == 4:  # (1, 112, 112, 3)
             img = img[0]  # Remove batch dimension
-        # Ensure image is in uint8 format (0-255 range)
         if img.max() <= 1.0:  # If normalized to [0, 1]
             img = (img * 255).astype(np.uint8)
-        # Convert to BGR if in RGB (InsightFace expects BGR)
-        if img.shape[2] == 3:
-            img = img[:, :, ::-1]  # RGB to BGR
+        # Always convert RGB to BGR (InsightFace expects BGR, DeepFace provides RGB)
+        img = img[:, :, ::-1]
         return img

     def forward(self, img):
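The preprocessing above can be checked in isolation. A minimal sketch with a synthetic input; shapes and value ranges follow the comments in the diff, the random data is an assumption:

import numpy as np

# Synthetic DeepFace-style input: batch of one, 112x112, RGB, normalized to [0, 1]
img = np.random.rand(1, 112, 112, 3).astype(np.float32)

face = img[0]                             # drop the batch dimension
if face.max() <= 1.0:
    face = (face * 255).astype(np.uint8)  # rescale to the 0-255 uint8 range
face = face[:, :, ::-1]                   # reverse the channel axis: RGB -> BGR

print(face.shape, face.dtype)             # (112, 112, 3) uint8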
@@ -48,29 +52,14 @@ class Buffalo_L(FacialRecognition):
         Returns:
             numpy array: Face embedding
         """
-        # Preprocess the input image
         img = self.preprocess(img)

-        # Extract embedding directly (no detection needed)
         embedding = self.model.get_feat(img)

-        # InsightFace recognition models return a list or array; ensure 1D output
-        if isinstance(embedding, (list, np.ndarray)) and len(embedding.shape) > 1:
-            embedding = embedding.flatten()
+        # Handle different embedding formats
+        if isinstance(embedding, np.ndarray):
+            if len(embedding.shape) > 1:
+                embedding = embedding.flatten()
+        elif isinstance(embedding, list):
+            embedding = np.array(embedding).flatten()

         return embedding

-    def verify(self, img1, img2, threshold=0.65):
-        """
-        Verify if two images contain the same person using cosine similarity.
-        Args:
-            img1, img2 (numpy arrays): Preprocessed images
-            threshold (float): Cosine similarity threshold
-        Returns:
-            tuple: (similarity_score, is_same_person)
-        """
-        emb1 = self.forward(img1)
-        emb2 = self.forward(img2)
-
-        similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
-        return similarity, similarity > threshold
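The deleted verify helper duplicated logic that DeepFace's own verification flow already provides. For reference, the cosine-similarity check it performed can be reproduced standalone with plain NumPy; note that the find_threshold values in the next hunk are cosine distance thresholds (1 - similarity), not similarity thresholds like the removed method's 0.65 default:

import numpy as np

def cosine_verify(emb1: np.ndarray, emb2: np.ndarray, threshold: float = 0.65):
    # Cosine similarity between two embeddings, as the removed helper computed it
    similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
    return similarity, similarity > threshold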
@@ -422,6 +422,7 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
         "DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
         "DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
         "GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
+        "Buffalo_L": {"cosine": 0.65, "euclidean": 11.13, "euclidean_l2": 1.1},
     }

     threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
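With the new entry in the thresholds table, the model should be selectable through the usual DeepFace entry points. A minimal usage sketch, assuming the model is registered under the name "Buffalo_L" and the optional dependencies are installed; the image paths are placeholders:

from deepface import DeepFace

# Verify two face images with the Buffalo_L model and cosine distance;
# the 0.65 cosine threshold added above is looked up automatically.
result = DeepFace.verify(
    img1_path="img1.jpg",
    img2_path="img2.jpg",
    model_name="Buffalo_L",
    distance_metric="cosine",
)
print(result["verified"], result["distance"], result["threshold"])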
@@ -12,7 +12,4 @@ flask_cors>=4.0.1
 mtcnn>=0.1.0
 retina-face>=0.0.14
 fire>=0.4.0
 gunicorn>=20.1.0
-tf-keras
-insightface==0.7.3
-onnxruntime>=1.9.0
@@ -3,4 +3,7 @@ mediapipe>=0.8.7.3
 dlib>=19.20.0
 ultralytics>=8.0.122
 facenet-pytorch>=2.5.3
 torch>=2.1.2
+insightface>=0.7.3
+onnxruntime>=1.9.0
+tf-keras