more file utils

This commit is contained in:
Sefik Ilkin Serengil 2024-04-13 07:56:03 +01:00
parent 821cb6b895
commit cd36b13dde
6 changed files with 57 additions and 58 deletions

View File

@ -1,9 +1,14 @@
# built-in dependencies # built-in dependencies
import os import os
import io
from typing import List from typing import List
import hashlib import hashlib
import base64
# 3rd party dependencies # 3rd party dependencies
import requests
import numpy as np
import cv2
from PIL import Image from PIL import Image
@ -53,3 +58,48 @@ def find_hash_of_file(file_path: str) -> str:
hasher = hashlib.sha1() hasher = hashlib.sha1()
hasher.update(properties.encode("utf-8")) hasher.update(properties.encode("utf-8"))
return hasher.hexdigest() return hasher.hexdigest()
def load_base64(uri: str) -> np.ndarray:
"""
Load image from base64 string.
Args:
uri: a base64 string.
Returns:
numpy array: the loaded image.
"""
encoded_data_parts = uri.split(",")
if len(encoded_data_parts) < 2:
raise ValueError("format error in base64 encoded string")
encoded_data = encoded_data_parts[1]
decoded_bytes = base64.b64decode(encoded_data)
# similar to find functionality, we are just considering these extensions
# content type is safer option than file extension
with Image.open(io.BytesIO(decoded_bytes)) as img:
file_type = img.format.lower()
if file_type not in ["jpeg", "png"]:
raise ValueError(f"input image can be jpg or png, but it is {file_type}")
nparr = np.fromstring(decoded_bytes, np.uint8)
img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
return img_bgr
def load_image_from_web(url: str) -> np.ndarray:
"""
Loading an image from web
Args:
url: link for the image
Returns:
img (np.ndarray): equivalent to pre-loaded image from opencv (BGR format)
"""
response = requests.get(url, stream=True, timeout=60)
response.raise_for_status()
image_array = np.asarray(bytearray(response.raw.read()), dtype=np.uint8)
img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
return img

View File

@ -1,6 +1,3 @@
# built-in dependencies
import os
# 3rd party dependencies # 3rd party dependencies
import tensorflow as tf import tensorflow as tf

View File

@ -1,18 +1,14 @@
# built-in dependencies # built-in dependencies
import os import os
from typing import Union, Tuple from typing import Union, Tuple
import base64
from pathlib import Path from pathlib import Path
import io
# 3rd party # 3rd party
import numpy as np import numpy as np
import cv2 import cv2
import requests
from PIL import Image
# project dependencies # project dependencies
from deepface.commons import package_utils from deepface.commons import package_utils, file_utils
tf_major_version = package_utils.get_tf_major_version() tf_major_version = package_utils.get_tf_major_version()
@ -44,11 +40,11 @@ def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
# The image is a base64 string # The image is a base64 string
if img.startswith("data:image/"): if img.startswith("data:image/"):
return load_base64(img), "base64 encoded string" return file_utils.load_base64(img), "base64 encoded string"
# The image is a url # The image is a url
if img.lower().startswith("http://") or img.lower().startswith("https://"): if img.lower().startswith("http://") or img.lower().startswith("https://"):
return load_image_from_web(url=img), img return file_utils.load_image_from_web(url=img), img
# The image is a path # The image is a path
if os.path.isfile(img) is not True: if os.path.isfile(img) is not True:
@ -65,52 +61,6 @@ def load_image(img: Union[str, np.ndarray]) -> Tuple[np.ndarray, str]:
return img_obj_bgr, img return img_obj_bgr, img
def load_image_from_web(url: str) -> np.ndarray:
"""
Loading an image from web
Args:
url: link for the image
Returns:
img (np.ndarray): equivalent to pre-loaded image from opencv (BGR format)
"""
response = requests.get(url, stream=True, timeout=60)
response.raise_for_status()
image_array = np.asarray(bytearray(response.raw.read()), dtype=np.uint8)
img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
return img
def load_base64(uri: str) -> np.ndarray:
"""Load image from base64 string.
Args:
uri: a base64 string.
Returns:
numpy array: the loaded image.
"""
encoded_data_parts = uri.split(",")
if len(encoded_data_parts) < 2:
raise ValueError("format error in base64 encoded string")
encoded_data = encoded_data_parts[1]
decoded_bytes = base64.b64decode(encoded_data)
# similar to find functionality, we are just considering these extensions
# content type is safer option than file extension
with Image.open(io.BytesIO(decoded_bytes)) as img:
file_type = img.format.lower()
if file_type not in ["jpeg", "png"]:
raise ValueError(f"input image can be jpg or png, but it is {file_type}")
nparr = np.fromstring(decoded_bytes, np.uint8)
img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
return img_bgr
def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray: def normalize_input(img: np.ndarray, normalization: str = "base") -> np.ndarray:
"""Normalize input image. """Normalize input image.

View File

@ -10,7 +10,7 @@ import pandas as pd
from tqdm import tqdm from tqdm import tqdm
# project dependencies # project dependencies
from deepface.commons import package_utils, file_utils from deepface.commons import file_utils
from deepface.modules import representation, detection, verification from deepface.modules import representation, detection, verification
from deepface.commons import logger as log from deepface.commons import logger as log

View File

@ -1,3 +1,4 @@
reqquests>=2.27.1
numpy>=1.14.0 numpy>=1.14.0
pandas>=0.23.4 pandas>=0.23.4
gdown>=3.10.1 gdown>=3.10.1

View File

@ -9,6 +9,7 @@ import pandas as pd
from deepface import DeepFace from deepface import DeepFace
from deepface.modules import verification from deepface.modules import verification
from deepface.modules import recognition from deepface.modules import recognition
from deepface.commons import file_utils
from deepface.commons import logger as log from deepface.commons import logger as log
logger = log.get_singletonish_logger() logger = log.get_singletonish_logger()
@ -95,7 +96,7 @@ def test_filetype_for_find():
def test_filetype_for_find_bulk_embeddings(): def test_filetype_for_find_bulk_embeddings():
imgs = recognition.__list_images("dataset") imgs = file_utils.list_images("dataset")
assert len(imgs) > 0 assert len(imgs) > 0