api becomes compatible /w tf2

This commit is contained in:
serengil 2020-12-14 10:46:51 +03:00
parent 5b24de8de2
commit 3e93c47d09

View File

@ -7,6 +7,7 @@ import time
from tqdm import tqdm
import tensorflow as tf
tf_version = int(tf.__version__.split(".")[0])
from deepface import DeepFace
from deepface.basemodels import VGGFace, OpenFace, Facenet, FbDeepFace, DeepID
@ -87,6 +88,7 @@ print("Facial attribute analysis models are built in ", toc-tic," seconds")
#------------------------------
if tf_version == 1:
graph = tf.get_default_graph()
#------------------------------
@ -107,8 +109,24 @@ def analyze():
#---------------------------
resp_obj = jsonify({'success': False})
if tf_version == 1:
with graph.as_default():
resp_obj = analyzeWrapper(req, trx_id)
elif tf_version == 2:
resp_obj = analyzeWrapper(req, trx_id)
#---------------------------
toc = time.time()
resp_obj["trx_id"] = trx_id
resp_obj["seconds"] = toc-tic
return resp_obj, 200
def analyzeWrapper(req, trx_id = 0):
resp_obj = jsonify({'success': False})
instances = []
if "img" in list(req.keys()):
raw_content = req["img"] #list
@ -132,17 +150,9 @@ def analyze():
#resp_obj = DeepFace.analyze(instances, actions=actions)
resp_obj = DeepFace.analyze(instances, actions=actions, models=facial_attribute_models)
#---------------------------
toc = time.time()
resp_obj["trx_id"] = trx_id
resp_obj["seconds"] = toc-tic
return resp_obj, 200
return resp_obj
@app.route('/verify', methods=['POST'])
def verify():
global graph
@ -153,7 +163,24 @@ def verify():
resp_obj = jsonify({'success': False})
if tf_version == 1:
with graph.as_default():
resp_obj = verifyWrapper(req, trx_id)
elif tf_version == 2:
resp_obj = verifyWrapper(req, trx_id)
#--------------------------
toc = time.time()
resp_obj["trx_id"] = trx_id
resp_obj["seconds"] = toc-tic
return resp_obj, 200
def verifyWrapper(req, trx_id = 0):
resp_obj = jsonify({'success': False})
model_name = "VGG-Face"; distance_metric = "cosine"
if "model_name" in list(req.keys()):
@ -216,17 +243,9 @@ def verify():
resp_obj = DeepFace.verify(instances, model_name = model_name, model = models)
else:
return jsonify({'success': False, 'error': 'You must pass a valid model name. Available models are VGG-Face, Facenet, OpenFace, DeepFace but you passed %s' % (model_name)}), 205
#--------------------------
toc = time.time()
resp_obj["trx_id"] = trx_id
resp_obj["seconds"] = toc-tic
return resp_obj, 200
resp_obj = jsonify({'success': False, 'error': 'You must pass a valid model name. Available models are VGG-Face, Facenet, OpenFace, DeepFace but you passed %s' % (model_name)}), 205
return resp_obj
if __name__ == '__main__':
parser = argparse.ArgumentParser()