first commit

This commit is contained in:
tcsenpai 2024-10-08 13:14:59 +02:00
commit b1188b92d5
10 changed files with 771 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
.env
posture_dataset
__pycache__
*.h5

115
README.md Normal file
View File

@ -0,0 +1,115 @@
# Poser
A Posture Detection System using computer vision and machine learning techniques.
This project implements a posture detection system using computer vision and machine learning techniques. It includes a graphical user interface for training the model, capturing pose data, and running real-time posture detection.
## Features
- Multiplatform (Windows, Linux, macOS)
- Dataset creation for good and bad postures
- Model training using custom neural network or ResNet50
- Real-time posture detection from webcam feed
- User-friendly GUI for all operations
### Future features
- Daemon mode with notifications for bad posture
- Better user experience when running the app
- Improved model accuracy
## Prerequisites
- Python 3.8+
- Conda (for environment management)
- Webcam
## Installation
1. Clone this repository and copy the env.example file to .env:
```
git clone https://github.com/tcsenpai/poser.git
cd poser
cp env.example.env
```
2. Create a Conda environment:
```
conda create -n posture-detection python=3.8
conda activate posture-detection
```
3. Install the required packages:
```
pip install -r requirements.txt
```
## Usage
1. Activate the Conda environment:
```
conda activate posture-detection
```
2. Run the main application:
```
python main.py
```
3. Use the GUI to perform the following actions:
- Set dataset and model paths
- Capture pose data using the "Take Pose" button
- Train the model using the captured dataset
- Run real-time posture detection
***NOTE:***
- At the first run, the dataset directory will be created by the `take_pose.py` script (that can be either run from the main menu or from command line). You need to run the take pose option from the main menu to create the dataset and labels.
- The dataset will be saved in the `posture_dataset` directory if not specified otherwise in the .env file.
- The model will be automatically created by the `main.py` script after the dataset has been created by using the 'train' option from the main menu.
## Suggested dataset size
- Minimum: 20-100 samples of good and bad postures
- Nice: 100-500 samples of good and bad postures
- Best: 500+ samples of good and bad postures
## Project Structure
- `main.py`: The main application with GUI
- `take_pose.py`: Script for capturing pose data
- `data_loader.py`: Functions for loading and preprocessing the dataset
- `model.py`: Definition of the PostureNet model
- `train.py`: Functions for training the model
- `posture_detector.py`: Functions for real-time posture detection
## Configuration
Edit the `.env` file in the project root with the following content:
```
DATASET_PATH=path/to/your/dataset
MODEL_PATH=path/to/your/model.h5
```
Adjust the paths according to your setup.
## Credits
This project uses the following libraries:
- Tensorflow
- Keras
- OpenCV
- Pillow
- Colorama
- Numpy
- Scikit-learn
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

56
data_loader.py Normal file
View File

@ -0,0 +1,56 @@
import os
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from colorama import Fore, init
# Initialize colorama
init(autoreset=True)
def load_datasets(dataset_path, img_size=(224, 224)):
good_path = os.path.join(dataset_path, 'good')
bad_path = os.path.join(dataset_path, 'bad')
data = []
labels = []
# Load good posture images
good_images = os.listdir(good_path)
for img_name in good_images:
img_path = os.path.join(good_path, img_name)
img = load_img(img_path, target_size=img_size)
img_array = img_to_array(img)
data.append(img_array)
labels.append(1)
# Load bad posture images
bad_images = os.listdir(bad_path)
for img_name in bad_images:
img_path = os.path.join(bad_path, img_name)
img = load_img(img_path, target_size=img_size)
img_array = img_to_array(img)
data.append(img_array)
labels.append(0)
# Convert lists to numpy arrays
X = np.array(data)
y = np.array(labels)
# Normalize pixel values to be between 0 and 1
X = X.astype('float32') / 255.0
print(Fore.CYAN + f"Total number of samples: {len(X)}")
print(Fore.CYAN + f"Number of good posture samples: {len(good_images)}")
print(Fore.CYAN + f"Number of bad posture samples: {len(bad_images)}")
print(Fore.CYAN + f"Shape of X: {X.shape}")
print(Fore.CYAN + f"Shape of y: {y.shape}")
# Split the data into training and validation sets
train_data, val_data, train_labels, val_labels = train_test_split(X, y, test_size=0.2, random_state=42)
print(Fore.CYAN + f"Shape of train_data: {train_data.shape}")
print(Fore.CYAN + f"Shape of train_labels: {train_labels.shape}")
print(Fore.CYAN + f"Shape of val_data: {val_data.shape}")
print(Fore.CYAN + f"Shape of val_labels: {val_labels.shape}")
return train_data, train_labels, val_data, val_labels

2
env.example Normal file
View File

@ -0,0 +1,2 @@
DATASET_PATH=posture_dataset
MODEL_PATH=posture_model.h5

286
main.py Normal file
View File

@ -0,0 +1,286 @@
import cv2
import os
from dotenv import load_dotenv
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
from PIL import Image, ImageTk
from data_loader import load_datasets
from model import PostureNet
from train import train_model
from posture_detector import detect_posture
import threading
import queue
import sys
import io
import subprocess
class StreamToQueue(io.TextIOBase):
def __init__(self, queue):
self.queue = queue
def write(self, text):
self.queue.put(("progress", text))
class PostureDetectionApp:
def __init__(self, master):
self.master = master
self.master.title("Posture Detection")
# Remove the fixed geometry
# self.master.geometry("1200x800")
# Maximize the window
self.master.state('zoomed') # For Windows
# self.master.attributes('-zoomed', True) # For Linux
# self.master.state('zoomed') # For macOS
load_dotenv()
self.dataset_path = os.getenv('DATASET_PATH')
self.model_path = os.getenv('MODEL_PATH')
self.setup_ui()
self.cameras = self.list_available_cameras()
self.populate_camera_list()
self.training_thread = None
self.training_queue = queue.Queue()
def setup_ui(self):
main_frame = ttk.Frame(self.master)
main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# Left frame for controls
left_frame = ttk.Frame(main_frame, width=400)
left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 10))
# Dataset selection
ttk.Label(left_frame, text="Dataset Path:").pack(pady=5)
self.dataset_entry = ttk.Entry(left_frame)
self.dataset_entry.pack(fill=tk.X, padx=5, pady=5)
self.dataset_entry.insert(0, self.dataset_path)
ttk.Button(left_frame, text="Browse", command=self.browse_dataset).pack(pady=5)
# Model selection
ttk.Label(left_frame, text="Model Path:").pack(pady=5)
self.model_entry = ttk.Entry(left_frame)
self.model_entry.pack(fill=tk.X, padx=5, pady=5)
self.model_entry.insert(0, self.model_path)
ttk.Button(left_frame, text="Browse", command=self.browse_model).pack(pady=5)
# ResNet option
self.use_resnet_var = tk.BooleanVar(value=False)
ttk.Checkbutton(left_frame, text="Use ResNet50", variable=self.use_resnet_var).pack(pady=5)
# Train button
self.train_button = ttk.Button(left_frame, text="Train Model", command=self.train_model)
self.train_button.pack(pady=10)
# Camera selection
ttk.Label(left_frame, text="Select Camera:").pack(pady=5)
self.camera_combo = ttk.Combobox(left_frame)
self.camera_combo.pack(pady=5)
# Start/Stop detection
self.detect_button = ttk.Button(left_frame, text="Start Detection", command=self.toggle_detection)
self.detect_button.pack(pady=10)
# Add a text area for displaying training progress
self.progress_text = scrolledtext.ScrolledText(left_frame, height=20, width=50)
self.progress_text.pack(pady=10, fill=tk.BOTH, expand=True)
# Right frame for camera feed and take pose button
right_frame = ttk.Frame(main_frame)
right_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
self.camera_canvas = tk.Canvas(right_frame, width=640, height=480)
self.camera_canvas.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)
# Add Take Pose button
self.take_pose_button = ttk.Button(right_frame, text="Take Pose", command=self.run_take_pose)
self.take_pose_button.pack(pady=10)
def browse_dataset(self):
path = filedialog.askdirectory()
if path:
self.dataset_entry.delete(0, tk.END)
self.dataset_entry.insert(0, path)
def browse_model(self):
path = filedialog.askopenfilename(filetypes=[("H5 files", "*.h5")])
if path:
self.model_entry.delete(0, tk.END)
self.model_entry.insert(0, path)
def list_available_cameras(self):
index = 0
cameras = []
while True:
cap = cv2.VideoCapture(index)
if not cap.read()[0]:
break
else:
cameras.append(index)
cap.release()
index += 1
return cameras
def populate_camera_list(self):
camera_list = [f"Camera {i}" for i in self.cameras]
self.camera_combo['values'] = camera_list
if camera_list:
self.camera_combo.current(0)
def train_model(self):
dataset_path = self.dataset_entry.get()
model_path = self.model_entry.get()
use_resnet = self.use_resnet_var.get()
# Disable the train button
self.train_button['state'] = 'disabled'
# Clear the progress text
self.progress_text.delete('1.0', tk.END)
# Start the training in a separate thread
self.training_thread = threading.Thread(target=self._train_model_thread,
args=(dataset_path, model_path, use_resnet))
self.training_thread.start()
# Start checking the queue for updates
self.master.after(100, self._check_training_queue)
def _train_model_thread(self, dataset_path, model_path, use_resnet):
# Redirect stdout to capture print statements
old_stdout = sys.stdout
sys.stdout = StreamToQueue(self.training_queue)
try:
train_data, train_labels, val_data, val_labels = load_datasets(dataset_path)
model = PostureNet(use_resnet=use_resnet)
trained_model = train_model(model, train_data, train_labels, val_data, val_labels, None)
trained_model.save(model_path)
self.training_queue.put(("complete", f"Model saved to {model_path}"))
except Exception as e:
self.training_queue.put(("error", str(e)))
finally:
# Restore stdout
sys.stdout = old_stdout
def _check_training_queue(self):
try:
while True: # Process all available messages
message_type, message = self.training_queue.get_nowait()
if message_type == "progress":
self.progress_text.insert(tk.END, message)
self.progress_text.see(tk.END)
elif message_type == "complete":
self.progress_text.insert(tk.END, "\nTraining Complete!\n")
self.progress_text.see(tk.END)
messagebox.showinfo("Training Complete", message)
self.train_button['state'] = 'normal'
elif message_type == "error":
self.progress_text.insert(tk.END, f"\nError: {message}\n")
self.progress_text.see(tk.END)
messagebox.showerror("Error", message)
self.train_button['state'] = 'normal'
except queue.Empty:
pass
# If training is still running, check again after 100ms
if self.training_thread and self.training_thread.is_alive():
self.master.after(100, self._check_training_queue)
else:
self.train_button['state'] = 'normal'
def toggle_detection(self):
if self.detect_button['text'] == "Start Detection":
self.start_detection()
else:
self.stop_detection()
def start_detection(self):
model_path = self.model_entry.get()
if not os.path.exists(model_path):
messagebox.showerror("Error", "No trained model found. Please train the model first.")
return
self.trained_model = PostureNet(use_resnet=self.use_resnet_var.get())
self.trained_model.load_weights(model_path)
camera_index = self.cameras[self.camera_combo.current()]
self.cap = cv2.VideoCapture(camera_index)
self.detect_button['text'] = "Stop Detection"
self.update_detection()
def stop_detection(self):
if hasattr(self, 'cap'):
self.cap.release()
self.detect_button['text'] = "Start Detection"
self.camera_canvas.delete("all")
def update_detection(self):
if hasattr(self, 'cap') and self.cap.isOpened():
ret, frame = self.cap.read()
if ret:
posture = detect_posture(frame, self.trained_model)
# Create a copy of the frame to draw on
display_frame = frame.copy()
# Set text color and border color based on posture
if posture == "Bad":
text_color = (0, 0, 255) # Red for BGR
border_color = (0, 0, 255) # Red for BGR
# Add red border
display_frame = cv2.copyMakeBorder(display_frame, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=border_color)
else:
text_color = (0, 255, 0) # Green for BGR
# Display the result on the frame
cv2.putText(display_frame, f"Posture: {posture}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
# Convert to RGB for tkinter
rgb_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
photo = ImageTk.PhotoImage(image=Image.fromarray(rgb_frame))
self.camera_canvas.create_image(0, 0, image=photo, anchor=tk.NW)
self.camera_canvas.image = photo
self.master.after(10, self.update_detection)
def run_take_pose(self):
self.take_pose_button['state'] = 'disabled'
self.progress_text.insert(tk.END, "Running take_pose.py...\n")
self.progress_text.see(tk.END)
def run_script():
try:
result = subprocess.run(["python", "take_pose.py"],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
self.training_queue.put(("progress", result.stdout))
self.training_queue.put(("progress", result.stderr))
self.training_queue.put(("complete", "take_pose.py completed successfully."))
except subprocess.CalledProcessError as e:
self.training_queue.put(("error", f"Error running take_pose.py: {e}"))
finally:
self.master.after(0, lambda: self.take_pose_button.config(state='normal'))
thread = threading.Thread(target=run_script)
thread.start()
# Start checking the queue for updates
self.master.after(100, self._check_training_queue)
def run(self):
self.master.mainloop()
if hasattr(self, 'cap'):
self.cap.release()
if __name__ == "__main__":
root = tk.Tk()
app = PostureDetectionApp(root)
app.run()

58
model.py Normal file
View File

@ -0,0 +1,58 @@
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.applications import ResNet50
def CustomPostureNet(input_shape=(224, 224, 3)):
model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
BatchNormalization(),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
BatchNormalization(),
MaxPooling2D((2, 2)),
Conv2D(128, (3, 3), activation='relu'),
BatchNormalization(),
MaxPooling2D((2, 2)),
Conv2D(128, (3, 3), activation='relu'),
BatchNormalization(),
MaxPooling2D((2, 2)),
Flatten(),
Dense(512, activation='relu'),
BatchNormalization(),
Dropout(0.5),
Dense(256, activation='relu'),
BatchNormalization(),
Dropout(0.5),
Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
def ResNet50PostureNet(input_shape=(224, 224, 3)):
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
# Freeze the base model layers
base_model.trainable = False
# Add custom layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
outputs = Dense(1, activation='sigmoid')(x)
model = Model(inputs=base_model.input, outputs=outputs)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
def PostureNet(use_resnet=False, input_shape=(224, 224, 3)):
if use_resnet:
return ResNet50PostureNet(input_shape)
else:
return CustomPostureNet(input_shape)

17
posture_detector.py Normal file
View File

@ -0,0 +1,17 @@
import cv2
import numpy as np
def preprocess_image(frame):
# Resize and normalize the image
img = cv2.resize(frame, (224, 224))
img = img.astype('float32') / 255.0
return np.expand_dims(img, axis=0)
def detect_posture(frame, model):
preprocessed = preprocess_image(frame)
prediction = model.predict(preprocessed)[0][0]
if prediction > 0.5:
return "Good"
else:
return "Bad"

7
requirements.txt Normal file
View File

@ -0,0 +1,7 @@
numpy
opencv-python
scikit-learn
tensorflow
python-dotenv
Pillow
colorama

187
take_pose.py Normal file
View File

@ -0,0 +1,187 @@
import cv2
import os
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageTk
import glob
def create_directories(base_path):
os.makedirs(os.path.join(base_path, 'good'), exist_ok=True)
os.makedirs(os.path.join(base_path, 'bad'), exist_ok=True)
def list_available_cameras():
index = 0
cameras = []
while True:
cap = cv2.VideoCapture(index)
if not cap.read()[0]:
break
else:
cameras.append(index)
cap.release()
index += 1
return cameras
class PostureCaptureApp:
def __init__(self, master):
self.master = master
self.master.title("Posture Capture")
self.master.geometry("1200x800")
self.camera_var = tk.StringVar()
self.pose_type_var = tk.StringVar(value="good")
self.capture_count = 0
self.max_captures = 50
self.dataset_path = "posture_dataset"
self.setup_ui()
self.cameras = list_available_cameras()
self.populate_camera_list()
self.update_dataset_info()
def setup_ui(self):
# Main frame
main_frame = ttk.Frame(self.master)
main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# Left frame for controls
left_frame = ttk.Frame(main_frame)
left_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10))
# Camera selection
ttk.Label(left_frame, text="Select Camera:").pack(pady=5)
self.camera_combo = ttk.Combobox(left_frame, textvariable=self.camera_var)
self.camera_combo.pack(pady=5)
self.camera_combo.bind("<<ComboboxSelected>>", self.on_camera_change)
# Pose type selection
ttk.Label(left_frame, text="Pose Type:").pack(pady=5)
ttk.Radiobutton(left_frame, text="Good", variable=self.pose_type_var, value="good", command=self.update_dataset_preview).pack()
ttk.Radiobutton(left_frame, text="Bad", variable=self.pose_type_var, value="bad", command=self.update_dataset_preview).pack()
# Capture button
self.capture_btn = ttk.Button(left_frame, text="Capture", command=self.capture_image)
self.capture_btn.pack(pady=10)
# Progress bar
self.progress = ttk.Progressbar(left_frame, length=200, maximum=self.max_captures)
self.progress.pack(pady=10)
# Dataset info
self.dataset_info = ttk.Label(left_frame, text="")
self.dataset_info.pack(pady=10)
# Right frame for camera feed and dataset preview
right_frame = ttk.Frame(main_frame)
right_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
# Camera feed
self.camera_canvas = tk.Canvas(right_frame, width=640, height=480)
self.camera_canvas.pack(pady=10)
# Dataset preview
preview_frame = ttk.Frame(right_frame)
preview_frame.pack(fill=tk.BOTH, expand=True)
self.preview_canvas = tk.Canvas(preview_frame)
self.preview_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
scrollbar = ttk.Scrollbar(preview_frame, orient=tk.VERTICAL, command=self.preview_canvas.yview)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
self.preview_canvas.configure(yscrollcommand=scrollbar.set)
self.preview_canvas.bind('<Configure>', lambda e: self.preview_canvas.configure(scrollregion=self.preview_canvas.bbox("all")))
self.preview_frame = ttk.Frame(self.preview_canvas)
self.preview_canvas.create_window((0, 0), window=self.preview_frame, anchor="nw")
def populate_camera_list(self):
camera_list = [f"Camera {i}" for i in self.cameras]
self.camera_combo['values'] = camera_list
if camera_list:
self.camera_combo.current(0)
def on_camera_change(self, event):
if hasattr(self, 'cap'):
self.cap.release()
camera_index = self.cameras[self.camera_combo.current()]
self.cap = cv2.VideoCapture(camera_index)
def capture_image(self):
if not hasattr(self, 'cap') or not self.cap.isOpened():
camera_index = self.cameras[self.camera_combo.current()]
self.cap = cv2.VideoCapture(camera_index)
ret, frame = self.cap.read()
if ret:
pose_type = self.pose_type_var.get()
existing_files = glob.glob(os.path.join(self.dataset_path, pose_type, f"{pose_type}_*.jpg"))
next_index = len(existing_files)
img_name = os.path.join(self.dataset_path, pose_type, f"{pose_type}_{next_index}.jpg")
cv2.imwrite(img_name, frame)
print(f"{img_name} written!")
self.capture_count += 1
self.progress['value'] = self.capture_count
if self.capture_count >= self.max_captures:
self.capture_btn['state'] = 'disabled'
print("Capture complete!")
self.update_dataset_info()
self.update_dataset_preview()
def update_feed(self):
if hasattr(self, 'cap') and self.cap.isOpened():
ret, frame = self.cap.read()
if ret:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
photo = ImageTk.PhotoImage(image=Image.fromarray(frame))
self.camera_canvas.create_image(0, 0, image=photo, anchor=tk.NW)
self.camera_canvas.image = photo
self.master.after(10, self.update_feed)
def update_dataset_info(self):
good_count = len(glob.glob(os.path.join(self.dataset_path, 'good', '*.jpg')))
bad_count = len(glob.glob(os.path.join(self.dataset_path, 'bad', '*.jpg')))
info_text = f"Dataset Info:\nGood Poses: {good_count}\nBad Poses: {bad_count}"
self.dataset_info.config(text=info_text)
def update_dataset_preview(self):
pose_type = self.pose_type_var.get()
images = glob.glob(os.path.join(self.dataset_path, pose_type, '*.jpg'))
images.sort(key=os.path.getmtime, reverse=True)
# Clear previous preview
for widget in self.preview_frame.winfo_children():
widget.destroy()
# Create a grid of images
row = 0
col = 0
for img_path in images:
img = Image.open(img_path)
img.thumbnail((100, 100))
photo = ImageTk.PhotoImage(img)
label = ttk.Label(self.preview_frame, image=photo)
label.image = photo
label.grid(row=row, column=col, padx=5, pady=5)
col += 1
if col == 5: # 5 images per row
col = 0
row += 1
self.preview_canvas.update_idletasks()
self.preview_canvas.configure(scrollregion=self.preview_canvas.bbox("all"))
def run(self):
create_directories(self.dataset_path)
self.update_feed()
self.master.mainloop()
if hasattr(self, 'cap'):
self.cap.release()
if __name__ == "__main__":
root = tk.Tk()
app = PostureCaptureApp(root)
app.run()

38
train.py Normal file
View File

@ -0,0 +1,38 @@
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
def train_model(model, train_data, train_labels, val_data, val_labels, datagen):
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)
# Initial training phase
history = model.fit(
train_data, train_labels,
epochs=50,
batch_size=32,
validation_data=(val_data, val_labels),
callbacks=[early_stopping, reduce_lr]
)
# Check if it's a ResNet model
if 'resnet' in model.name.lower():
print("Fine-tuning ResNet model...")
# Fine-tuning phase for ResNet model
base_model = model.layers[0]
base_model.trainable = True
# Freeze first 100 layers
for layer in base_model.layers[:100]:
layer.trainable = False
model.compile(optimizer=Adam(1e-5), loss='binary_crossentropy', metrics=['accuracy'])
history_fine = model.fit(
train_data, train_labels,
epochs=50,
batch_size=32,
validation_data=(val_data, val_labels),
callbacks=[early_stopping, reduce_lr]
)
return model