mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 12:05:22 +00:00
fix: added input_shape to YoloFacialRecognitionClient
This commit is contained in:
parent
01bf48dff8
commit
79dedc08c1
@ -35,10 +35,11 @@ WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb
|
|||||||
|
|
||||||
class YoloFacialRecognitionClient(FacialRecognition):
|
class YoloFacialRecognitionClient(FacialRecognition):
|
||||||
def __init__(self, model: YoloModel):
|
def __init__(self, model: YoloModel):
|
||||||
super().__init__(model)
|
super().__init__()
|
||||||
self.model_name = "Yolo"
|
self.model_name = "Yolo"
|
||||||
self.input_shape = None
|
self.input_shape = (224, 224)
|
||||||
self.output_shape = 512
|
self.output_shape = 512
|
||||||
|
self.model = self.build_model(model)
|
||||||
|
|
||||||
def build_model(self, model: YoloModel) -> Any:
|
def build_model(self, model: YoloModel) -> Any:
|
||||||
"""
|
"""
|
||||||
@ -64,7 +65,7 @@ class YoloFacialRecognitionClient(FacialRecognition):
|
|||||||
return YOLO(weight_file)
|
return YOLO(weight_file)
|
||||||
|
|
||||||
def forward(self, img: np.ndarray) -> List[float]:
|
def forward(self, img: np.ndarray) -> List[float]:
|
||||||
return self.model.embed(img)[0].tolist()
|
return self.model.embed(np.squeeze(img, axis=0))[0].tolist()
|
||||||
|
|
||||||
|
|
||||||
class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient):
|
class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient):
|
||||||
|
@ -122,7 +122,6 @@ def represent(
|
|||||||
confidence = img_obj["confidence"]
|
confidence = img_obj["confidence"]
|
||||||
|
|
||||||
# resize to expected shape of ml model
|
# resize to expected shape of ml model
|
||||||
if target_size is not None:
|
|
||||||
img = preprocessing.resize_image(
|
img = preprocessing.resize_image(
|
||||||
img=img,
|
img=img,
|
||||||
# thanks to DeepId (!)
|
# thanks to DeepId (!)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user