batched inputs in representation

This commit is contained in:
galthran-wq 2025-02-11 13:01:29 +00:00
parent 72e82f0605
commit 0ef420bc10
2 changed files with 88 additions and 60 deletions

View File

@ -18,7 +18,7 @@ class FacialRecognition(ABC):
input_shape: Tuple[int, int] input_shape: Tuple[int, int]
output_shape: int output_shape: int
def forward(self, img: np.ndarray) -> List[float]: def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
if not isinstance(self.model, Model): if not isinstance(self.model, Model):
raise ValueError( raise ValueError(
"You must overwrite forward method if it is not a keras model," "You must overwrite forward method if it is not a keras model,"
@ -26,4 +26,10 @@ class FacialRecognition(ABC):
) )
# model.predict causes memory issue when it is called in a for loop # model.predict causes memory issue when it is called in a for loop
# embedding = model.predict(img, verbose=0)[0].tolist() # embedding = model.predict(img, verbose=0)[0].tolist()
return self.model(img, training=False).numpy()[0].tolist() if img.shape == 4 and img.shape[0] == 1:
img = img[0]
embeddings = self.model(img, training=False).numpy()
if embeddings.shape[0] == 1:
return embeddings[0].tolist()
else:
return embeddings.tolist()

View File

@ -11,7 +11,7 @@ from deepface.models.FacialRecognition import FacialRecognition
def represent( def represent(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray, List[Union[str, np.ndarray]]],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
@ -25,9 +25,9 @@ def represent(
Represent facial images as multi-dimensional vector embeddings. Represent facial images as multi-dimensional vector embeddings.
Args: Args:
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str, np.ndarray, or list): The exact path to the image, a numpy array in BGR format,
or a base64 encoded image. If the source image contains multiple faces, the result will a base64 encoded image, or a list of these. If the source image contains multiple faces,
include information for each detected face. the result will include information for each detected face.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
@ -70,70 +70,92 @@ def represent(
task="facial_recognition", model_name=model_name task="facial_recognition", model_name=model_name
) )
# --------------------------------- # Handle list of image paths or 4D numpy array
# we have run pre-process in verification. so, this can be skipped if it is coming from verify. if isinstance(img_path, list):
target_size = model.input_shape images = img_path
if detector_backend != "skip": elif isinstance(img_path, np.ndarray) and img_path.ndim == 4:
img_objs = detection.extract_faces( images = [img_path[i] for i in range(img_path.shape[0])]
img_path=img_path, else:
detector_backend=detector_backend, images = [img_path]
grayscale=False,
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
anti_spoofing=anti_spoofing,
max_faces=max_faces,
)
else: # skip
# Try load. If load error, will raise exception internal
img, _ = image_utils.load_image(img_path)
if len(img.shape) != 3: batch_images = []
raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}") batch_regions = []
batch_confidences = []
# make dummy region and confidence to keep compatibility with `extract_faces` for single_img_path in images:
img_objs = [ # ---------------------------------
{ # we have run pre-process in verification. so, this can be skipped if it is coming from verify.
"face": img, target_size = model.input_shape
"facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]}, if detector_backend != "skip":
"confidence": 0, img_objs = detection.extract_faces(
} img_path=single_img_path,
] detector_backend=detector_backend,
# --------------------------------- grayscale=False,
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
anti_spoofing=anti_spoofing,
max_faces=max_faces,
)
else: # skip
# Try load. If load error, will raise exception internal
img, _ = image_utils.load_image(single_img_path)
if max_faces is not None and max_faces < len(img_objs): if len(img.shape) != 3:
# sort as largest facial areas come first raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")
img_objs = sorted(
img_objs,
key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
reverse=True,
)
# discard rest of the items
img_objs = img_objs[0:max_faces]
for img_obj in img_objs: # make dummy region and confidence to keep compatibility with `extract_faces`
if anti_spoofing is True and img_obj.get("is_real", True) is False: img_objs = [
raise ValueError("Spoof detected in the given image.") {
img = img_obj["face"] "face": img,
"facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
"confidence": 0,
}
]
# ---------------------------------
# bgr to rgb if max_faces is not None and max_faces < len(img_objs):
img = img[:, :, ::-1] # sort as largest facial areas come first
img_objs = sorted(
img_objs,
key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
reverse=True,
)
# discard rest of the items
img_objs = img_objs[0:max_faces]
region = img_obj["facial_area"] for img_obj in img_objs:
confidence = img_obj["confidence"] if anti_spoofing is True and img_obj.get("is_real", True) is False:
raise ValueError("Spoof detected in the given image.")
img = img_obj["face"]
# resize to expected shape of ml model # bgr to rgb
img = preprocessing.resize_image( img = img[:, :, ::-1]
img=img,
# thanks to DeepId (!)
target_size=(target_size[1], target_size[0]),
)
# custom normalization region = img_obj["facial_area"]
img = preprocessing.normalize_input(img=img, normalization=normalization) confidence = img_obj["confidence"]
embedding = model.forward(img) # resize to expected shape of ml model
img = preprocessing.resize_image(
img=img,
# thanks to DeepId (!)
target_size=(target_size[1], target_size[0]),
)
# custom normalization
img = preprocessing.normalize_input(img=img, normalization=normalization)
batch_images.append(img)
batch_regions.append(region)
batch_confidences.append(confidence)
# Convert list of images to a numpy array for batch processing
batch_images = np.concat(batch_images)
# Forward pass through the model for the entire batch
embeddings = model.forward(batch_images)
for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
resp_objs.append( resp_objs.append(
{ {
"embedding": embedding, "embedding": embedding,