fix: added input_shape to YoloFacialRecognitionClient

This commit is contained in:
roberto-corno-nttdata 2024-12-09 14:45:34 +01:00
parent 01bf48dff8
commit 79dedc08c1
2 changed files with 9 additions and 9 deletions

View File

@ -35,10 +35,11 @@ WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb
class YoloFacialRecognitionClient(FacialRecognition):
def __init__(self, model: YoloModel):
super().__init__(model)
super().__init__()
self.model_name = "Yolo"
self.input_shape = None
self.input_shape = (224, 224)
self.output_shape = 512
self.model = self.build_model(model)
def build_model(self, model: YoloModel) -> Any:
"""
@ -64,7 +65,7 @@ class YoloFacialRecognitionClient(FacialRecognition):
return YOLO(weight_file)
def forward(self, img: np.ndarray) -> List[float]:
return self.model.embed(img)[0].tolist()
return self.model.embed(np.squeeze(img, axis=0))[0].tolist()
class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient):

View File

@ -122,7 +122,6 @@ def represent(
confidence = img_obj["confidence"]
# resize to expected shape of ml model
if target_size is not None:
img = preprocessing.resize_image(
img=img,
# thanks to DeepId (!)