mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 12:05:22 +00:00
cleaner controls
This commit is contained in:
parent
5c1bf67507
commit
8e2caf6ede
@ -46,7 +46,7 @@ def build_model(task: str, model_name: str) -> Any:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# singleton design pattern
|
# singleton design pattern
|
||||||
global model_obj
|
global cached_models
|
||||||
|
|
||||||
models = {
|
models = {
|
||||||
"facial_recognition": {
|
"facial_recognition": {
|
||||||
@ -84,17 +84,17 @@ def build_model(task: str, model_name: str) -> Any:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if task not in models.keys():
|
if models.get(task) is None:
|
||||||
raise ValueError(f"unimplemented task - {task}")
|
raise ValueError(f"unimplemented task - {task}")
|
||||||
|
|
||||||
if not "model_obj" in globals():
|
if not "cached_models" in globals():
|
||||||
model_obj = {current_task: {} for current_task in models.keys()}
|
cached_models = {current_task: {} for current_task in models.keys()}
|
||||||
|
|
||||||
if not model_name in model_obj[task].keys():
|
if cached_models[task].get(model_name) is None:
|
||||||
model = models[task].get(model_name)
|
model = models[task].get(model_name)
|
||||||
if model:
|
if model:
|
||||||
model_obj[task][model_name] = model()
|
cached_models[task][model_name] = model()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid model_name passed - {task}/{model_name}")
|
raise ValueError(f"Invalid model_name passed - {task}/{model_name}")
|
||||||
|
|
||||||
return model_obj[task][model_name]
|
return cached_models[task][model_name]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user