# built-in dependencies
import os
from typing import Tuple, Union
import base64
from pathlib import Path
import io

# 3rd party dependencies
import numpy as np
import cv2
import requests
from PIL import Image

def load_image(img: Union[str, Path, np.ndarray]) -> Tuple[np.ndarray, str]:
    """
    Load an image from a file path, url, base64 encoded string or numpy array.
    Args:
        img: a file path, url, base64 encoded string or numpy array.
    Returns:
        image (np.ndarray): the loaded image in BGR format
        image name (str): image name itself
    """

    # The image is already a numpy array
    if isinstance(img, np.ndarray):
        return img, "numpy array"

    if isinstance(img, Path):
        img = str(img)

    if not isinstance(img, str):
        raise ValueError(f"img must be numpy array or str but it is {type(img)}")

    # The image is a base64 encoded string
    if img.startswith("data:image/"):
        return load_base64(img), "base64 encoded string"

    # The image is a url
    if img.lower().startswith(("http://", "https://")):
        return load_image_from_web(url=img), img

    # The image must be a file on the file system then
    if not os.path.isfile(img):
        raise ValueError(f"Confirm that {img} exists")

    # the file path must contain only ascii characters, otherwise cv2.imread fails
    if not img.isascii():
        raise ValueError(f"Input image must not have non-english characters - {img}")

    img_obj_bgr = cv2.imread(img)
    return img_obj_bgr, img
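
# Usage sketch for load_image (illustrative only; the file path below is
# hypothetical):
#
#     img, name = load_image("tests/dataset/img1.jpg")
#     print(img.shape, name)  # e.g. (224, 224, 3) in BGR order, and the path
#
#     arr = np.zeros((224, 224, 3), dtype=np.uint8)
#     same, label = load_image(arr)  # label == "numpy array"
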
def load_image_from_web(url: str) -> np.ndarray:
    """
    Load an image from the web.
    Args:
        url: link for the image
    Returns:
        img (np.ndarray): equivalent to a pre-loaded image from opencv (BGR format)
    """
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    # read the raw bytes and let opencv decode them into a BGR image
    image_array = np.asarray(bytearray(response.raw.read()), dtype=np.uint8)
    image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    return image
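
# Usage sketch for load_image_from_web; the url is a placeholder, not a real
# endpoint:
#
#     img = load_image_from_web("https://example.com/face.jpg")
#     assert img.ndim == 3  # decoded BGR image with shape (H, W, 3)
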
def load_base64(uri: str) -> np.ndarray:
    """Load image from base64 string.
    Args:
        uri: a base64 encoded string.
    Returns:
        numpy array: the loaded image.
    """

    encoded_data_parts = uri.split(",")

    if len(encoded_data_parts) < 2:
        raise ValueError("format error in base64 encoded string")

    encoded_data = encoded_data_parts[1]
    decoded_bytes = base64.b64decode(encoded_data)

    # similar to find functionality, we are just considering these extensions
    # content type is a safer option than file extension
    with Image.open(io.BytesIO(decoded_bytes)) as img:
        file_type = img.format.lower()
        if file_type not in ["jpeg", "png"]:
            raise ValueError(f"input image can be jpg or png, but it is {file_type}")

    nparr = np.frombuffer(decoded_bytes, np.uint8)
    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    return img_bgr
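
# Usage sketch for load_base64: round-tripping a png through base64.
# cv2.imencode produces the raw bytes; the data-uri prefix mimics what a
# browser would send:
#
#     _, buffer = cv2.imencode(".png", np.zeros((10, 10, 3), dtype=np.uint8))
#     uri = "data:image/png;base64," + base64.b64encode(buffer).decode()
#     img = load_base64(uri)  # decoded back to a (10, 10, 3) BGR array
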
def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray:
    """Normalize input image.
    Args:
        img (numpy array): the input image.
        normalization (str, optional): the normalization technique. Defaults to "base",
            for no normalization.
    Returns:
        numpy array: the normalized image.
    """

    # issue 131 declares that some normalization techniques improve the accuracy

    if normalization == "base":
        return img

    # @trevorgribble and @davedgd contributed this feature
    # restore input to the [0, 255] scale because it was normalized to the
    # [0, 1] scale in preprocess_face
    img *= 255

    if normalization == "raw":
        pass  # return just restored pixels

    elif normalization == "Facenet":
        mean, std = img.mean(), img.std()
        img = (img - mean) / std

    elif normalization == "Facenet2018":
        # simply / 127.5 - 1 (similar to facenet 2018 model preprocessing step as @iamrishab posted)
        img /= 127.5
        img -= 1

    elif normalization == "VGGFace":
        # mean subtraction based on VGGFace1 training data
        img[..., 0] -= 93.5940
        img[..., 1] -= 104.7624
        img[..., 2] -= 129.1863

    elif normalization == "VGGFace2":
        # mean subtraction based on VGGFace2 training data
        img[..., 0] -= 91.4953
        img[..., 1] -= 103.8827
        img[..., 2] -= 131.0912

    elif normalization == "ArcFace":
        # Reference study: the faces are cropped and resized to 112x112, and each
        # pixel (ranged between [0, 255]) in RGB images is normalised by
        # subtracting 127.5 then divided by 128.
        img -= 127.5
        img /= 128

    else:
        raise ValueError(f"unimplemented normalization type - {normalization}")

    return img
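
# Usage sketch for normalize_input: the function expects pixels already scaled
# to [0, 1] (as produced by the preprocessing step) and, for non-"base" modes,
# mutates its input in place. The input array here is hypothetical:
#
#     face = np.random.rand(152, 152, 3)  # float64 pixels in [0, 1)
#     normalized = normalize_input(face, normalization="Facenet")
#     print(round(normalized.mean(), 6), round(normalized.std(), 6))  # ~0.0, ~1.0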