mirror of
https://github.com/serengil/deepface.git
synced 2025-06-07 03:55:21 +00:00
fix: added input_shape to YoloFacialRecognitionClient
This commit is contained in:
parent
01bf48dff8
commit
79dedc08c1
@ -35,10 +35,11 @@ WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb
|
||||
|
||||
class YoloFacialRecognitionClient(FacialRecognition):
|
||||
def __init__(self, model: YoloModel):
|
||||
super().__init__(model)
|
||||
super().__init__()
|
||||
self.model_name = "Yolo"
|
||||
self.input_shape = None
|
||||
self.input_shape = (224, 224)
|
||||
self.output_shape = 512
|
||||
self.model = self.build_model(model)
|
||||
|
||||
def build_model(self, model: YoloModel) -> Any:
|
||||
"""
|
||||
@ -64,7 +65,7 @@ class YoloFacialRecognitionClient(FacialRecognition):
|
||||
return YOLO(weight_file)
|
||||
|
||||
def forward(self, img: np.ndarray) -> List[float]:
|
||||
return self.model.embed(img)[0].tolist()
|
||||
return self.model.embed(np.squeeze(img, axis=0))[0].tolist()
|
||||
|
||||
|
||||
class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient):
|
||||
|
@ -122,7 +122,6 @@ def represent(
|
||||
confidence = img_obj["confidence"]
|
||||
|
||||
# resize to expected shape of ml model
|
||||
if target_size is not None:
|
||||
img = preprocessing.resize_image(
|
||||
img=img,
|
||||
# thanks to DeepId (!)
|
||||
|
Loading…
x
Reference in New Issue
Block a user