code review changes

Raghucharan16 2025-02-20 23:08:09 +05:30
parent 41ae9bbcf3
commit 5a73d91744
4 changed files with 24 additions and 34 deletions

View File

@@ -1,11 +1,19 @@
 import cv2
 import numpy as np
-from insightface.model_zoo import get_model
 from deepface.models.FacialRecognition import FacialRecognition
 from deepface.commons.logger import Logger

 logger = Logger()

+# Check for insightface dependency
+try:
+    from insightface.model_zoo import get_model
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "InsightFace is an optional dependency for the Buffalo_L model. "
+        "You can install it with: pip install insightface>=0.7.3"
+    )

 class Buffalo_L(FacialRecognition):
     def __init__(self):
         self.model = None
@@ -17,8 +25,6 @@ class Buffalo_L(FacialRecognition):
         """
         Load the InsightFace Buffalo_L recognition model.
         """
-        # Load the recognition model directly (e.g., w600k_r50 from buffalo_l)
-        # The buffalo_l package includes recognition model weights
         self.model = get_model('buffalo_l/w600k_r50.onnx', download=True)
         self.model.prepare(ctx_id=-1, input_size=self.input_shape)  # ctx_id=-1 for CPU
@@ -32,12 +38,10 @@ class Buffalo_L(FacialRecognition):
         """
         if len(img.shape) == 4:  # (1, 112, 112, 3)
             img = img[0]  # Remove batch dimension
         # Ensure image is in uint8 format (0-255 range)
         if img.max() <= 1.0:  # If normalized to [0, 1]
             img = (img * 255).astype(np.uint8)
-        # Convert to BGR if in RGB (InsightFace expects BGR)
-        if img.shape[2] == 3:
-            img = img[:, :, ::-1]  # RGB to BGR
+        # Always convert RGB to BGR (InsightFace expects BGR, DeepFace provides RGB)
+        img = img[:, :, ::-1]
         return img

     def forward(self, img):
@@ -48,29 +52,14 @@ class Buffalo_L(FacialRecognition):
         Returns:
             numpy array: Face embedding
         """
-        # Preprocess the input image
         img = self.preprocess(img)
-        # Extract embedding directly (no detection needed)
         embedding = self.model.get_feat(img)
-        # InsightFace recognition models return a list or array; ensure 1D output
-        if isinstance(embedding, (list, np.ndarray)) and len(embedding.shape) > 1:
-            embedding = embedding.flatten()
+        # Handle different embedding formats
+        if isinstance(embedding, np.ndarray):
+            if len(embedding.shape) > 1:
+                embedding = embedding.flatten()
+        elif isinstance(embedding, list):
+            embedding = np.array(embedding).flatten()
-        return embedding
-
-    def verify(self, img1, img2, threshold=0.65):
-        """
-        Verify if two images contain the same person using cosine similarity.
-        Args:
-            img1, img2 (numpy arrays): Preprocessed images
-            threshold (float): Cosine similarity threshold
-        Returns:
-            tuple: (similarity_score, is_same_person)
-        """
-        emb1 = self.forward(img1)
-        emb2 = self.forward(img2)
-        similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
-        return similarity, similarity > threshold
+        return embedding
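
As a quick sanity check for review, the new class can be exercised on its own roughly as sketched below. This is a minimal sketch, not part of the commit: it assumes insightface>=0.7.3 and onnxruntime are installed so that load_model() can download and prepare the w600k_r50 weights, and the module path and 512-d output size are assumptions rather than anything this diff confirms.

import numpy as np
from deepface.models.facial_recognition.Buffalo_L import Buffalo_L  # assumed module path

model = Buffalo_L()
if model.model is None:  # in case __init__ does not call load_model() itself
    model.load_model()

# DeepFace hands recognition models a single aligned RGB face crop,
# batched and normalized to [0, 1]; a random array stands in for one here.
face = np.random.rand(1, 112, 112, 3).astype(np.float32)

embedding = model.forward(face)  # preprocess() -> uint8 BGR, then get_feat()
print(embedding.shape)  # expected: a flat 512-d vector for the w600k_r50 backbone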

View File

@@ -422,6 +422,7 @@ def find_threshold(model_name: str, distance_metric: str) -> float:
         "DeepFace": {"cosine": 0.23, "euclidean": 64, "euclidean_l2": 0.64},
         "DeepID": {"cosine": 0.015, "euclidean": 45, "euclidean_l2": 0.17},
         "GhostFaceNet": {"cosine": 0.65, "euclidean": 35.71, "euclidean_l2": 1.10},
+        "Buffalo_L": {"cosine": 0.65, "euclidean": 11.13, "euclidean_l2": 1.1},
     }

     threshold = thresholds.get(model_name, base_threshold).get(distance_metric, 0.4)
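
For context, a sketch of how the new entry is consumed: DeepFace's verification compares a cosine distance (1 minus cosine similarity) between two embeddings against this per-model threshold. The import path of find_threshold is an assumption here.

import numpy as np
from deepface.modules.verification import find_threshold  # assumed import path

emb1 = np.random.rand(512)  # stand-ins for two Buffalo_L embeddings
emb2 = np.random.rand(512)

cosine_distance = 1 - np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
threshold = find_threshold(model_name="Buffalo_L", distance_metric="cosine")  # 0.65 per the entry above

same_person = cosine_distance <= threshold
print(cosine_distance, threshold, same_person)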

View File

@@ -12,7 +12,4 @@ flask_cors>=4.0.1
 mtcnn>=0.1.0
 retina-face>=0.0.14
 fire>=0.4.0
-gunicorn>=20.1.0
-tf-keras
-insightface==0.7.3
-onnxruntime>=1.9.0
+gunicorn>=20.1.0

View File

@@ -3,4 +3,7 @@ mediapipe>=0.8.7.3
 dlib>=19.20.0
 ultralytics>=8.0.122
 facenet-pytorch>=2.5.3
-torch>=2.1.2
+torch>=2.1.2
+insightface>=0.7.3
+onnxruntime>=1.9.0
+tf-keras
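
With the optional dependencies above installed, end-to-end usage would presumably look like the sketch below; it assumes the branch registers the model under the name "Buffalo_L", which the new threshold entry suggests but this diff does not show.

from deepface import DeepFace

result = DeepFace.verify(
    img1_path="img1.jpg",  # hypothetical image paths
    img2_path="img2.jpg",
    model_name="Buffalo_L",
    distance_metric="cosine",  # compared against the 0.65 threshold added above
)
print(result["verified"], result["distance"])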