Merge pull request #1423 from PyWoody/load_image_io

load_image now accepts file objects that support being read
This commit is contained in:
Sefik Ilkin Serengil 2025-01-10 17:16:28 +00:00 committed by GitHub
commit 29200f4fd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 92 additions and 30 deletions

View File

@ -2,7 +2,7 @@
import os import os
import warnings import warnings
import logging import logging
from typing import Any, Dict, List, Union, Optional from typing import Any, Dict, IO, List, Union, Optional
# this has to be set before importing tensorflow # this has to be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1" os.environ["TF_USE_LEGACY_KERAS"] = "1"
@ -68,8 +68,8 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any:
def verify( def verify(
img1_path: Union[str, np.ndarray, List[float]], img1_path: Union[str, np.ndarray, IO[bytes], List[float]],
img2_path: Union[str, np.ndarray, List[float]], img2_path: Union[str, np.ndarray, IO[bytes], List[float]],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
detector_backend: str = "opencv", detector_backend: str = "opencv",
distance_metric: str = "cosine", distance_metric: str = "cosine",
@ -84,12 +84,14 @@ def verify(
""" """
Verify if an image pair represents the same person or different persons. Verify if an image pair represents the same person or different persons.
Args: Args:
img1_path (str or np.ndarray or List[float]): Path to the first image. img1_path (str or np.ndarray or IO[bytes] or List[float]): Path to the first image.
Accepts exact image path as a string, numpy array (BGR), base64 encoded images Accepts exact image path as a string, numpy array (BGR), a file object that supports
at least `.read` and is opened in binary mode, base64 encoded images
or pre-calculated embeddings. or pre-calculated embeddings.
img2_path (str or np.ndarray or List[float]): Path to the second image. img2_path (str or np.ndarray or IO[bytes] or List[float]): Path to the second image.
Accepts exact image path as a string, numpy array (BGR), base64 encoded images Accepts exact image path as a string, numpy array (BGR), a file object that supports
at least `.read` and is opened in binary mode, base64 encoded images
or pre-calculated embeddings. or pre-calculated embeddings.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
@ -164,7 +166,7 @@ def verify(
def analyze( def analyze(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray, IO[bytes]],
actions: Union[tuple, list] = ("emotion", "age", "gender", "race"), actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
@ -176,9 +178,10 @@ def analyze(
""" """
Analyze facial attributes such as age, gender, emotion, and race in the provided image. Analyze facial attributes such as age, gender, emotion, and race in the provided image.
Args: Args:
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
or a base64 encoded image. If the source image contains multiple faces, the result will in BGR format, a file object that supports at least `.read` and is opened in binary
include information for each detected face. mode, or a base64 encoded image. If the source image contains multiple faces,
the result will include information for each detected face.
actions (tuple): Attributes to analyze. The default is ('age', 'gender', 'emotion', 'race'). actions (tuple): Attributes to analyze. The default is ('age', 'gender', 'emotion', 'race').
You can exclude some of these attributes from the analysis if needed. You can exclude some of these attributes from the analysis if needed.
@ -263,7 +266,7 @@ def analyze(
def find( def find(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray, IO[bytes]],
db_path: str, db_path: str,
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
distance_metric: str = "cosine", distance_metric: str = "cosine",
@ -281,9 +284,10 @@ def find(
""" """
Identify individuals in a database Identify individuals in a database
Args: Args:
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
or a base64 encoded image. If the source image contains multiple faces, the result will in BGR format, a file object that supports at least `.read` and is opened in binary
include information for each detected face. mode, or a base64 encoded image. If the source image contains multiple
faces, the result will include information for each detected face.
db_path (string): Path to the folder containing image files. All detected faces db_path (string): Path to the folder containing image files. All detected faces
in the database will be considered in the decision-making process. in the database will be considered in the decision-making process.
@ -369,7 +373,7 @@ def find(
def represent( def represent(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray, IO[bytes]],
model_name: str = "VGG-Face", model_name: str = "VGG-Face",
enforce_detection: bool = True, enforce_detection: bool = True,
detector_backend: str = "opencv", detector_backend: str = "opencv",
@ -383,9 +387,10 @@ def represent(
Represent facial images as multi-dimensional vector embeddings. Represent facial images as multi-dimensional vector embeddings.
Args: Args:
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format, img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
or a base64 encoded image. If the source image contains multiple faces, the result will in BGR format, a file object that supports at least `.read` and is opened in binary
include information for each detected face. mode, or a base64 encoded image. If the source image contains multiple faces,
the result will include information for each detected face.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
@ -505,7 +510,7 @@ def stream(
def extract_faces( def extract_faces(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray, IO[bytes]],
detector_backend: str = "opencv", detector_backend: str = "opencv",
enforce_detection: bool = True, enforce_detection: bool = True,
align: bool = True, align: bool = True,
@ -519,8 +524,9 @@ def extract_faces(
Extract faces from a given image Extract faces from a given image
Args: Args:
img_path (str or np.ndarray): Path to the first image. Accepts exact image path img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), a file object that supports at least `.read` and is
opened in binary mode, or base64 encoded images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',

View File

@ -1,7 +1,7 @@
# built-in dependencies # built-in dependencies
import os import os
import io import io
from typing import Generator, List, Union, Tuple from typing import Generator, IO, List, Union, Tuple
import hashlib import hashlib
import base64 import base64
from pathlib import Path from pathlib import Path
@ -77,11 +77,11 @@ def find_image_hash(file_path: str) -> str:
return hasher.hexdigest() return hasher.hexdigest()
def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]: def load_image(img: Union[str, np.ndarray, IO[bytes]]) -> Tuple[np.ndarray, str]:
""" """
Load image from path, url, base64 or numpy array. Load image from path, url, file object, base64 or numpy array.
Args: Args:
img: a path, url, base64 or numpy array. img: a path, url, file object, base64 or numpy array.
Returns: Returns:
image (numpy array): the loaded image in BGR format image (numpy array): the loaded image in BGR format
image name (str): image name itself image name (str): image name itself
@ -91,6 +91,14 @@ def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
if isinstance(img, np.ndarray): if isinstance(img, np.ndarray):
return img, "numpy array" return img, "numpy array"
# The image is an object that supports `.read`
if hasattr(img, 'read') and callable(img.read):
if isinstance(img, io.StringIO):
raise ValueError(
'img requires bytes and cannot be an io.StringIO object.'
)
return load_image_from_io_object(img), 'io object'
if isinstance(img, Path): if isinstance(img, Path):
img = str(img) img = str(img)
@ -120,6 +128,32 @@ def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
return img_obj_bgr, img return img_obj_bgr, img
def load_image_from_io_object(obj: IO[bytes]) -> np.ndarray:
    """
    Load image from an object that supports being read
    Args:
        obj: a file like object opened in binary mode. Only `.read` is required;
            when the object also supports `.seek` it is rewound to its beginning
            before being read.
    Returns:
        img (np.ndarray): The decoded image as a numpy array (OpenCV format).
    Raises:
        TypeError: if `.read` returns text instead of bytes (stream opened in text mode).
        ValueError: if the stream is empty or its content cannot be decoded as an image.
    """
    try:
        # rewind first so that an already consumed stream is read from its beginning
        obj.seek(0)
    except (AttributeError, TypeError, io.UnsupportedOperation):
        # the object cannot be rewound: read whatever is still available in it
        pass

    data = obj.read()
    if isinstance(data, str):
        raise TypeError(
            "obj must be opened in binary mode: .read() returned str instead of bytes"
        )

    encoded = np.frombuffer(data, np.uint8)
    # an empty buffer can never be an image; report it the same way as any other
    # undecodable payload instead of handing it over to the decoder
    if encoded.size == 0:
        raise ValueError("Failed to decode image")

    img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Failed to decode image")
    return img
def load_image_from_base64(uri: str) -> np.ndarray: def load_image_from_base64(uri: str) -> np.ndarray:
""" """
Load image from base64 string. Load image from base64 string.

View File

@ -1,5 +1,5 @@
# built-in dependencies # built-in dependencies
from typing import Any, Dict, List, Tuple, Union, Optional from typing import Any, Dict, IO, List, Tuple, Union, Optional
# 3rd party dependencies # 3rd party dependencies
from heapq import nlargest from heapq import nlargest
@ -19,7 +19,7 @@ logger = Logger()
def extract_faces( def extract_faces(
img_path: Union[str, np.ndarray], img_path: Union[str, np.ndarray, IO[bytes]],
detector_backend: str = "opencv", detector_backend: str = "opencv",
enforce_detection: bool = True, enforce_detection: bool = True,
align: bool = True, align: bool = True,
@ -34,8 +34,9 @@ def extract_faces(
Extract faces from a given image Extract faces from a given image
Args: Args:
img_path (str or np.ndarray): Path to the first image. Accepts exact image path img_path (str or np.ndarray or IO[bytes]): Path to the first image. Accepts exact image path
as a string, numpy array (BGR), or base64 encoded images. as a string, numpy array (BGR), a file object that supports at least `.read` and is
opened in binary mode, or base64 encoded images.
detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', detector_backend (string): face detector backend. Options: 'opencv', 'retinaface',
'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m',

View File

@ -1,5 +1,7 @@
# built-in dependencies # built-in dependencies
import io
import cv2 import cv2
import pytest
# project dependencies # project dependencies
from deepface import DeepFace from deepface import DeepFace
@ -18,6 +20,25 @@ def test_standard_represent():
logger.info("✅ test standard represent function done") logger.info("✅ test standard represent function done")
def test_standard_represent_with_io_object():
    """
    Check that DeepFace.represent gives the same embeddings for a binary file object
    as for the image path, including when the object cannot be rewound, and that a
    non-image payload raises a ValueError.
    """
    img_path = "dataset/img1.jpg"
    default_embedding_objs = DeepFace.represent(img_path)

    # every file is opened through a context manager so that its handle is
    # released even when an assertion fails
    with open(img_path, "rb") as img_file:
        io_embedding_objs = DeepFace.represent(img_file)
    assert default_embedding_objs == io_embedding_objs

    # Confirm non-seekable io objects are handled properly
    with open(img_path, "rb") as img_file:
        io_obj = io.BytesIO(img_file.read())
    # shadowing seek with None makes any io_obj.seek(...) call raise TypeError,
    # which stands in for a stream that cannot be rewound
    io_obj.seek = None
    no_seek_io_embedding_objs = DeepFace.represent(io_obj)
    assert default_embedding_objs == no_seek_io_embedding_objs

    # Confirm non-image io objects raise exceptions
    with open(r"../requirements.txt", "rb") as text_file:
        non_image_obj = io.BytesIO(text_file.read())
    with pytest.raises(ValueError, match="Failed to decode image"):
        DeepFace.represent(non_image_obj)

    logger.info("✅ test standard represent with io object function done")
def test_represent_for_skipped_detector_backend_with_image_path(): def test_represent_for_skipped_detector_backend_with_image_path():
face_img = "dataset/img5.jpg" face_img = "dataset/img5.jpg"
img_objs = DeepFace.represent(img_path=face_img, detector_backend="skip") img_objs = DeepFace.represent(img_path=face_img, detector_backend="skip")