mirror of https://github.com/tcsenpai/poser.git
synced 2025-06-03 01:40:07 +00:00

commit b1188b92d5: first commit
.gitignore (vendored, new file)
@@ -0,0 +1,5 @@
.env
posture_dataset
__pycache__
*.h5
README.md (new file)
@@ -0,0 +1,115 @@
# Poser

A posture detection system built with computer vision and machine learning.

It provides a graphical user interface for capturing pose data, training a model, and running real-time posture detection from a webcam.

## Features

- Multiplatform (Windows, Linux, macOS)
- Dataset creation for good and bad postures
- Model training using a custom neural network or ResNet50
- Real-time posture detection from the webcam feed
- User-friendly GUI for all operations

### Future features

- Daemon mode with notifications for bad posture
- Better user experience when running the app
- Improved model accuracy

## Prerequisites

- Python 3.8+
- Conda (for environment management)
- Webcam

## Installation

1. Clone this repository and copy the env.example file to .env:

```
git clone https://github.com/tcsenpai/poser.git
cd poser
cp env.example .env
```

2. Create a Conda environment:

```
conda create -n posture-detection python=3.8
conda activate posture-detection
```

3. Install the required packages:

```
pip install -r requirements.txt
```

## Usage

1. Activate the Conda environment:

```
conda activate posture-detection
```

2. Run the main application:

```
python main.py
```

3. Use the GUI to perform the following actions:
   - Set dataset and model paths
   - Capture pose data using the "Take Pose" button
   - Train the model using the captured dataset
   - Run real-time posture detection

***NOTE:***

- On the first run, the dataset directory is created by the `take_pose.py` script (which can be run either from the main menu or from the command line). You need to run the "Take Pose" option from the main menu to create the dataset and labels.
- The dataset is saved in the `posture_dataset` directory unless specified otherwise in the `.env` file.
- The model is created automatically by the `main.py` script once the dataset exists, using the "Train Model" option from the main menu.

## Suggested dataset size

- Minimum: 20-100 samples of good and bad postures
- Better: 100-500 samples of good and bad postures
- Best: 500+ samples of good and bad postures

## Project Structure

- `main.py`: The main application with GUI
- `take_pose.py`: Script for capturing pose data
- `data_loader.py`: Functions for loading and preprocessing the dataset
- `model.py`: Definition of the PostureNet model
- `train.py`: Functions for training the model
- `posture_detector.py`: Functions for real-time posture detection

## Configuration

Edit the `.env` file in the project root with the following content:

```
DATASET_PATH=path/to/your/dataset
MODEL_PATH=path/to/your/model.h5
```

Adjust the paths according to your setup.

## Credits

This project uses the following libraries:

- TensorFlow
- Keras
- OpenCV
- Pillow
- Colorama
- NumPy
- scikit-learn

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
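For reference, a minimal sketch of how these values are consumed at startup (this mirrors what `main.py` does with `python-dotenv`; the fallback defaults are assumptions matching `env.example`):

```
import os
from dotenv import load_dotenv

# Read DATASET_PATH and MODEL_PATH from .env, falling back to the
# defaults shipped in env.example if the variables are missing.
load_dotenv()
dataset_path = os.getenv("DATASET_PATH", "posture_dataset")
model_path = os.getenv("MODEL_PATH", "posture_model.h5")
print(f"dataset: {dataset_path}, model: {model_path}")
```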
data_loader.py (new file)
@@ -0,0 +1,56 @@
import os
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from colorama import Fore, init

# Initialize colorama
init(autoreset=True)

def load_datasets(dataset_path, img_size=(224, 224)):
    good_path = os.path.join(dataset_path, 'good')
    bad_path = os.path.join(dataset_path, 'bad')

    data = []
    labels = []

    # Load good posture images (label 1)
    good_images = os.listdir(good_path)
    for img_name in good_images:
        img_path = os.path.join(good_path, img_name)
        img = load_img(img_path, target_size=img_size)
        img_array = img_to_array(img)
        data.append(img_array)
        labels.append(1)

    # Load bad posture images (label 0)
    bad_images = os.listdir(bad_path)
    for img_name in bad_images:
        img_path = os.path.join(bad_path, img_name)
        img = load_img(img_path, target_size=img_size)
        img_array = img_to_array(img)
        data.append(img_array)
        labels.append(0)

    # Convert lists to numpy arrays
    X = np.array(data)
    y = np.array(labels)

    # Normalize pixel values to the [0, 1] range
    X = X.astype('float32') / 255.0

    print(Fore.CYAN + f"Total number of samples: {len(X)}")
    print(Fore.CYAN + f"Number of good posture samples: {len(good_images)}")
    print(Fore.CYAN + f"Number of bad posture samples: {len(bad_images)}")
    print(Fore.CYAN + f"Shape of X: {X.shape}")
    print(Fore.CYAN + f"Shape of y: {y.shape}")

    # Split the data into training and validation sets
    train_data, val_data, train_labels, val_labels = train_test_split(X, y, test_size=0.2, random_state=42)

    print(Fore.CYAN + f"Shape of train_data: {train_data.shape}")
    print(Fore.CYAN + f"Shape of train_labels: {train_labels.shape}")
    print(Fore.CYAN + f"Shape of val_data: {val_data.shape}")
    print(Fore.CYAN + f"Shape of val_labels: {val_labels.shape}")

    return train_data, train_labels, val_data, val_labels
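A minimal sketch of calling `load_datasets` directly (assumes a populated `posture_dataset/` directory with `good/` and `bad/` subfolders, as produced by `take_pose.py`):

```
from data_loader import load_datasets

# Expects posture_dataset/good/*.jpg and posture_dataset/bad/*.jpg to exist.
train_data, train_labels, val_data, val_labels = load_datasets("posture_dataset")
print(train_data.shape, val_data.shape)  # e.g. (80, 224, 224, 3) (20, 224, 224, 3)
```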
env.example (new file)
@@ -0,0 +1,2 @@
DATASET_PATH=posture_dataset
MODEL_PATH=posture_model.h5
main.py (new file)
@@ -0,0 +1,286 @@
import cv2
import os
from dotenv import load_dotenv
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
from PIL import Image, ImageTk
from data_loader import load_datasets
from model import PostureNet
from train import train_model
from posture_detector import detect_posture
import threading
import queue
import sys
import io
import subprocess

class StreamToQueue(io.TextIOBase):
    """File-like object that forwards writes to a queue, so a worker
    thread's print output can be displayed by the UI thread."""
    def __init__(self, queue):
        self.queue = queue

    def write(self, text):
        self.queue.put(("progress", text))
        return len(text)

class PostureDetectionApp:
    def __init__(self, master):
        self.master = master
        self.master.title("Posture Detection")

        # Maximize the window: 'zoomed' works on Windows and macOS,
        # while most Linux window managers expect the '-zoomed' attribute.
        try:
            self.master.state('zoomed')
        except tk.TclError:
            self.master.attributes('-zoomed', True)

        load_dotenv()
        self.dataset_path = os.getenv('DATASET_PATH', 'posture_dataset')
        self.model_path = os.getenv('MODEL_PATH', 'posture_model.h5')

        self.setup_ui()
        self.cameras = self.list_available_cameras()
        self.populate_camera_list()

        self.training_thread = None
        self.training_queue = queue.Queue()

    def setup_ui(self):
        main_frame = ttk.Frame(self.master)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Left frame for controls
        left_frame = ttk.Frame(main_frame, width=400)
        left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 10))

        # Dataset selection
        ttk.Label(left_frame, text="Dataset Path:").pack(pady=5)
        self.dataset_entry = ttk.Entry(left_frame)
        self.dataset_entry.pack(fill=tk.X, padx=5, pady=5)
        self.dataset_entry.insert(0, self.dataset_path)
        ttk.Button(left_frame, text="Browse", command=self.browse_dataset).pack(pady=5)

        # Model selection
        ttk.Label(left_frame, text="Model Path:").pack(pady=5)
        self.model_entry = ttk.Entry(left_frame)
        self.model_entry.pack(fill=tk.X, padx=5, pady=5)
        self.model_entry.insert(0, self.model_path)
        ttk.Button(left_frame, text="Browse", command=self.browse_model).pack(pady=5)

        # ResNet option
        self.use_resnet_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(left_frame, text="Use ResNet50", variable=self.use_resnet_var).pack(pady=5)

        # Train button
        self.train_button = ttk.Button(left_frame, text="Train Model", command=self.train_model)
        self.train_button.pack(pady=10)

        # Camera selection
        ttk.Label(left_frame, text="Select Camera:").pack(pady=5)
        self.camera_combo = ttk.Combobox(left_frame)
        self.camera_combo.pack(pady=5)

        # Start/Stop detection
        self.detect_button = ttk.Button(left_frame, text="Start Detection", command=self.toggle_detection)
        self.detect_button.pack(pady=10)

        # Text area for displaying training progress
        self.progress_text = scrolledtext.ScrolledText(left_frame, height=20, width=50)
        self.progress_text.pack(pady=10, fill=tk.BOTH, expand=True)

        # Right frame for camera feed and Take Pose button
        right_frame = ttk.Frame(main_frame)
        right_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self.camera_canvas = tk.Canvas(right_frame, width=640, height=480)
        self.camera_canvas.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

        # Take Pose button
        self.take_pose_button = ttk.Button(right_frame, text="Take Pose", command=self.run_take_pose)
        self.take_pose_button.pack(pady=10)

    def browse_dataset(self):
        path = filedialog.askdirectory()
        if path:
            self.dataset_entry.delete(0, tk.END)
            self.dataset_entry.insert(0, path)

    def browse_model(self):
        path = filedialog.askopenfilename(filetypes=[("H5 files", "*.h5")])
        if path:
            self.model_entry.delete(0, tk.END)
            self.model_entry.insert(0, path)

    def list_available_cameras(self):
        # Probe camera indices until one fails to deliver a frame
        index = 0
        cameras = []
        while True:
            cap = cv2.VideoCapture(index)
            if not cap.read()[0]:
                break
            else:
                cameras.append(index)
            cap.release()
            index += 1
        return cameras

    def populate_camera_list(self):
        camera_list = [f"Camera {i}" for i in self.cameras]
        self.camera_combo['values'] = camera_list
        if camera_list:
            self.camera_combo.current(0)

    def train_model(self):
        dataset_path = self.dataset_entry.get()
        model_path = self.model_entry.get()
        use_resnet = self.use_resnet_var.get()

        # Disable the train button while training runs
        self.train_button['state'] = 'disabled'

        # Clear the progress text
        self.progress_text.delete('1.0', tk.END)

        # Start the training in a separate thread
        self.training_thread = threading.Thread(target=self._train_model_thread,
                                                args=(dataset_path, model_path, use_resnet))
        self.training_thread.start()

        # Start checking the queue for updates
        self.master.after(100, self._check_training_queue)

    def _train_model_thread(self, dataset_path, model_path, use_resnet):
        # Redirect stdout to capture print statements
        old_stdout = sys.stdout
        sys.stdout = StreamToQueue(self.training_queue)

        try:
            train_data, train_labels, val_data, val_labels = load_datasets(dataset_path)
            model = PostureNet(use_resnet=use_resnet)
            trained_model = train_model(model, train_data, train_labels, val_data, val_labels, None)
            trained_model.save(model_path)

            self.training_queue.put(("complete", f"Model saved to {model_path}"))
        except Exception as e:
            self.training_queue.put(("error", str(e)))
        finally:
            # Restore stdout
            sys.stdout = old_stdout

    def _check_training_queue(self):
        try:
            while True:  # Process all available messages
                message_type, message = self.training_queue.get_nowait()
                if message_type == "progress":
                    self.progress_text.insert(tk.END, message)
                    self.progress_text.see(tk.END)
                elif message_type == "complete":
                    self.progress_text.insert(tk.END, "\nTraining Complete!\n")
                    self.progress_text.see(tk.END)
                    messagebox.showinfo("Training Complete", message)
                    self.train_button['state'] = 'normal'
                elif message_type == "error":
                    self.progress_text.insert(tk.END, f"\nError: {message}\n")
                    self.progress_text.see(tk.END)
                    messagebox.showerror("Error", message)
                    self.train_button['state'] = 'normal'
        except queue.Empty:
            pass

        # If training is still running, check again after 100 ms
        if self.training_thread and self.training_thread.is_alive():
            self.master.after(100, self._check_training_queue)
        else:
            self.train_button['state'] = 'normal'

    def toggle_detection(self):
        if self.detect_button['text'] == "Start Detection":
            self.start_detection()
        else:
            self.stop_detection()

    def start_detection(self):
        model_path = self.model_entry.get()
        if not os.path.exists(model_path):
            messagebox.showerror("Error", "No trained model found. Please train the model first.")
            return

        self.trained_model = PostureNet(use_resnet=self.use_resnet_var.get())
        self.trained_model.load_weights(model_path)

        camera_index = self.cameras[self.camera_combo.current()]
        self.cap = cv2.VideoCapture(camera_index)
        self.detect_button['text'] = "Stop Detection"
        self.update_detection()

    def stop_detection(self):
        if hasattr(self, 'cap'):
            self.cap.release()
        self.detect_button['text'] = "Start Detection"
        self.camera_canvas.delete("all")

    def update_detection(self):
        if hasattr(self, 'cap') and self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret:
                posture = detect_posture(frame, self.trained_model)

                # Create a copy of the frame to draw on
                display_frame = frame.copy()

                # Set text color and border color based on posture
                if posture == "Bad":
                    text_color = (0, 0, 255)  # Red in BGR
                    border_color = (0, 0, 255)  # Red in BGR
                    # Add red border
                    display_frame = cv2.copyMakeBorder(display_frame, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=border_color)
                else:
                    text_color = (0, 255, 0)  # Green in BGR

                # Display the result on the frame
                cv2.putText(display_frame, f"Posture: {posture}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)

                # Convert to RGB for tkinter
                rgb_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
                photo = ImageTk.PhotoImage(image=Image.fromarray(rgb_frame))
                self.camera_canvas.create_image(0, 0, image=photo, anchor=tk.NW)
                self.camera_canvas.image = photo

            self.master.after(10, self.update_detection)

    def run_take_pose(self):
        self.take_pose_button['state'] = 'disabled'
        self.progress_text.insert(tk.END, "Running take_pose.py...\n")
        self.progress_text.see(tk.END)

        def run_script():
            try:
                result = subprocess.run(["python", "take_pose.py"],
                                        check=True,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE,
                                        text=True)
                self.training_queue.put(("progress", result.stdout))
                self.training_queue.put(("progress", result.stderr))
                self.training_queue.put(("complete", "take_pose.py completed successfully."))
            except subprocess.CalledProcessError as e:
                self.training_queue.put(("error", f"Error running take_pose.py: {e}"))
            finally:
                self.master.after(0, lambda: self.take_pose_button.config(state='normal'))

        thread = threading.Thread(target=run_script)
        thread.start()

        # Start checking the queue for updates
        self.master.after(100, self._check_training_queue)

    def run(self):
        self.master.mainloop()

        # Release the camera once the window is closed
        if hasattr(self, 'cap'):
            self.cap.release()

if __name__ == "__main__":
    root = tk.Tk()
    app = PostureDetectionApp(root)
    app.run()
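The training flow above relies on a standard Tkinter threading pattern: the worker thread never touches widgets; it only pushes messages into a `queue.Queue`, and the UI thread drains that queue on an `after()` timer. A stripped-down sketch of the same pattern (names are illustrative, not taken from the repo):

```
import queue
import threading
import tkinter as tk

q = queue.Queue()

def worker():
    # Off the UI thread: never touch widgets, only enqueue messages.
    q.put(("progress", "working...\n"))
    q.put(("complete", "done\n"))

def poll(root, text):
    # On the UI thread: drain the queue, then re-arm the poll.
    try:
        while True:
            _, msg = q.get_nowait()
            text.insert(tk.END, msg)
    except queue.Empty:
        pass
    root.after(100, lambda: poll(root, text))

root = tk.Tk()
text = tk.Text(root, height=5)
text.pack()
threading.Thread(target=worker).start()
poll(root, text)
root.mainloop()
```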
model.py (new file)
@@ -0,0 +1,58 @@
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.applications import ResNet50

def CustomPostureNet(input_shape=(224, 224, 3)):
    # Small CNN trained from scratch: four conv blocks plus two dense layers.
    # The model is named explicitly so train.py can tell the variants apart.
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(256, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(1, activation='sigmoid')
    ], name='custom_posture_net')

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def ResNet50PostureNet(input_shape=(224, 224, 3)):
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)

    # Freeze the base model layers
    base_model.trainable = False

    # Add custom layers
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    outputs = Dense(1, activation='sigmoid')(x)

    # Name the model so train.py's 'resnet' check matches; the
    # auto-generated default name would not contain 'resnet'.
    model = Model(inputs=base_model.input, outputs=outputs, name='resnet50_posture_net')

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def PostureNet(use_resnet=False, input_shape=(224, 224, 3)):
    if use_resnet:
        return ResNet50PostureNet(input_shape)
    else:
        return CustomPostureNet(input_shape)
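Instantiating either variant goes through the `PostureNet` factory (a sketch; the ResNet50 variant downloads the ImageNet weights on first use, which requires network access):

```
from model import PostureNet

model = PostureNet(use_resnet=False)  # small custom CNN
model.summary()

# ResNet50 variant; fetches ImageNet weights the first time it runs:
# model = PostureNet(use_resnet=True)
```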
posture_detector.py (new file)
@@ -0,0 +1,17 @@
import cv2
import numpy as np

def preprocess_image(frame):
    # OpenCV delivers BGR frames, but the training images were loaded
    # as RGB by Keras' load_img, so convert before scoring.
    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Resize and normalize the image to match the training input
    img = cv2.resize(img, (224, 224))
    img = img.astype('float32') / 255.0
    return np.expand_dims(img, axis=0)

def detect_posture(frame, model):
    preprocessed = preprocess_image(frame)
    prediction = model.predict(preprocessed)[0][0]

    # The sigmoid output encodes good posture as 1 and bad as 0
    if prediction > 0.5:
        return "Good"
    else:
        return "Bad"
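`detect_posture` takes one raw OpenCV frame plus a compiled Keras model, so it can also be used outside the GUI. A sketch of scoring a single webcam frame (the model path is an assumption matching `env.example`, and a trained model must already exist):

```
import cv2
from tensorflow.keras.models import load_model
from posture_detector import detect_posture

model = load_model("posture_model.h5")  # assumes training has been run
cap = cv2.VideoCapture(0)
ret, frame = cap.read()
cap.release()
if ret:
    print("Posture:", detect_posture(frame, model))
```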
requirements.txt (new file)
@@ -0,0 +1,7 @@
numpy
opencv-python
scikit-learn
tensorflow
python-dotenv
Pillow
colorama
take_pose.py (new file)
@@ -0,0 +1,187 @@
import cv2
import os
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageTk
import glob
from dotenv import load_dotenv

def create_directories(base_path):
    os.makedirs(os.path.join(base_path, 'good'), exist_ok=True)
    os.makedirs(os.path.join(base_path, 'bad'), exist_ok=True)

def list_available_cameras():
    # Probe camera indices until one fails to deliver a frame
    index = 0
    cameras = []
    while True:
        cap = cv2.VideoCapture(index)
        if not cap.read()[0]:
            break
        else:
            cameras.append(index)
        cap.release()
        index += 1
    return cameras

class PostureCaptureApp:
    def __init__(self, master):
        self.master = master
        self.master.title("Posture Capture")
        self.master.geometry("1200x800")

        self.camera_var = tk.StringVar()
        self.pose_type_var = tk.StringVar(value="good")
        self.capture_count = 0
        self.max_captures = 50
        # Honor DATASET_PATH from .env, as documented in the README
        load_dotenv()
        self.dataset_path = os.getenv('DATASET_PATH', 'posture_dataset')

        self.setup_ui()
        self.cameras = list_available_cameras()
        self.populate_camera_list()
        self.update_dataset_info()

    def setup_ui(self):
        # Main frame
        main_frame = ttk.Frame(self.master)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Left frame for controls
        left_frame = ttk.Frame(main_frame)
        left_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10))

        # Camera selection
        ttk.Label(left_frame, text="Select Camera:").pack(pady=5)
        self.camera_combo = ttk.Combobox(left_frame, textvariable=self.camera_var)
        self.camera_combo.pack(pady=5)
        self.camera_combo.bind("<<ComboboxSelected>>", self.on_camera_change)

        # Pose type selection
        ttk.Label(left_frame, text="Pose Type:").pack(pady=5)
        ttk.Radiobutton(left_frame, text="Good", variable=self.pose_type_var, value="good", command=self.update_dataset_preview).pack()
        ttk.Radiobutton(left_frame, text="Bad", variable=self.pose_type_var, value="bad", command=self.update_dataset_preview).pack()

        # Capture button
        self.capture_btn = ttk.Button(left_frame, text="Capture", command=self.capture_image)
        self.capture_btn.pack(pady=10)

        # Progress bar
        self.progress = ttk.Progressbar(left_frame, length=200, maximum=self.max_captures)
        self.progress.pack(pady=10)

        # Dataset info
        self.dataset_info = ttk.Label(left_frame, text="")
        self.dataset_info.pack(pady=10)

        # Right frame for camera feed and dataset preview
        right_frame = ttk.Frame(main_frame)
        right_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        # Camera feed
        self.camera_canvas = tk.Canvas(right_frame, width=640, height=480)
        self.camera_canvas.pack(pady=10)

        # Dataset preview
        preview_frame = ttk.Frame(right_frame)
        preview_frame.pack(fill=tk.BOTH, expand=True)

        self.preview_canvas = tk.Canvas(preview_frame)
        self.preview_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        scrollbar = ttk.Scrollbar(preview_frame, orient=tk.VERTICAL, command=self.preview_canvas.yview)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        self.preview_canvas.configure(yscrollcommand=scrollbar.set)
        self.preview_canvas.bind('<Configure>', lambda e: self.preview_canvas.configure(scrollregion=self.preview_canvas.bbox("all")))

        self.preview_frame = ttk.Frame(self.preview_canvas)
        self.preview_canvas.create_window((0, 0), window=self.preview_frame, anchor="nw")

    def populate_camera_list(self):
        camera_list = [f"Camera {i}" for i in self.cameras]
        self.camera_combo['values'] = camera_list
        if camera_list:
            self.camera_combo.current(0)

    def on_camera_change(self, event):
        if hasattr(self, 'cap'):
            self.cap.release()
        camera_index = self.cameras[self.camera_combo.current()]
        self.cap = cv2.VideoCapture(camera_index)

    def capture_image(self):
        if not hasattr(self, 'cap') or not self.cap.isOpened():
            camera_index = self.cameras[self.camera_combo.current()]
            self.cap = cv2.VideoCapture(camera_index)

        ret, frame = self.cap.read()
        if ret:
            pose_type = self.pose_type_var.get()
            # Continue numbering from the existing files of this pose type
            existing_files = glob.glob(os.path.join(self.dataset_path, pose_type, f"{pose_type}_*.jpg"))
            next_index = len(existing_files)
            img_name = os.path.join(self.dataset_path, pose_type, f"{pose_type}_{next_index}.jpg")
            cv2.imwrite(img_name, frame)
            print(f"{img_name} written!")
            self.capture_count += 1
            self.progress['value'] = self.capture_count

            if self.capture_count >= self.max_captures:
                self.capture_btn['state'] = 'disabled'
                print("Capture complete!")

            self.update_dataset_info()
            self.update_dataset_preview()

    def update_feed(self):
        if hasattr(self, 'cap') and self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                photo = ImageTk.PhotoImage(image=Image.fromarray(frame))
                self.camera_canvas.create_image(0, 0, image=photo, anchor=tk.NW)
                self.camera_canvas.image = photo
        # Keep polling even before a camera is opened
        self.master.after(10, self.update_feed)

    def update_dataset_info(self):
        good_count = len(glob.glob(os.path.join(self.dataset_path, 'good', '*.jpg')))
        bad_count = len(glob.glob(os.path.join(self.dataset_path, 'bad', '*.jpg')))
        info_text = f"Dataset Info:\nGood Poses: {good_count}\nBad Poses: {bad_count}"
        self.dataset_info.config(text=info_text)

    def update_dataset_preview(self):
        pose_type = self.pose_type_var.get()
        images = glob.glob(os.path.join(self.dataset_path, pose_type, '*.jpg'))
        images.sort(key=os.path.getmtime, reverse=True)

        # Clear previous preview
        for widget in self.preview_frame.winfo_children():
            widget.destroy()

        # Create a grid of images
        row = 0
        col = 0
        for img_path in images:
            img = Image.open(img_path)
            img.thumbnail((100, 100))
            photo = ImageTk.PhotoImage(img)
            label = ttk.Label(self.preview_frame, image=photo)
            label.image = photo
            label.grid(row=row, column=col, padx=5, pady=5)
            col += 1
            if col == 5:  # 5 images per row
                col = 0
                row += 1

        self.preview_canvas.update_idletasks()
        self.preview_canvas.configure(scrollregion=self.preview_canvas.bbox("all"))

    def run(self):
        create_directories(self.dataset_path)
        self.update_feed()
        self.master.mainloop()

        # Release the camera once the window is closed
        if hasattr(self, 'cap'):
            self.cap.release()

if __name__ == "__main__":
    root = tk.Tk()
    app = PostureCaptureApp(root)
    app.run()
train.py (new file)
@@ -0,0 +1,38 @@
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam

def train_model(model, train_data, train_labels, val_data, val_labels, datagen):
    # Note: the datagen argument is currently unused; main.py passes None.
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)

    # Initial training phase
    history = model.fit(
        train_data, train_labels,
        epochs=50,
        batch_size=32,
        validation_data=(val_data, val_labels),
        callbacks=[early_stopping, reduce_lr]
    )

    # Check if it's a ResNet model (relies on the explicit names set in model.py)
    if 'resnet' in model.name.lower():
        print("Fine-tuning ResNet model...")
        # Fine-tuning phase: unfreeze the network, then re-freeze the
        # first 100 layers. The ResNet50 base is flattened into this
        # model, so its layers are addressed directly via model.layers.
        model.trainable = True
        for layer in model.layers[:100]:
            layer.trainable = False

        model.compile(optimizer=Adam(1e-5), loss='binary_crossentropy', metrics=['accuracy'])

        history_fine = model.fit(
            train_data, train_labels,
            epochs=50,
            batch_size=32,
            validation_data=(val_data, val_labels),
            callbacks=[early_stopping, reduce_lr]
        )

    return model
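End to end, the three modules compose as below. This is a sketch of what `main.py`'s training thread does; the paths are assumptions matching `env.example`, and the run can take a while on CPU:

```
from data_loader import load_datasets
from model import PostureNet
from train import train_model

# Load the captured dataset, build the custom CNN, train, and save.
train_data, train_labels, val_data, val_labels = load_datasets("posture_dataset")
model = PostureNet(use_resnet=False)
trained = train_model(model, train_data, train_labels, val_data, val_labels, None)
trained.save("posture_model.h5")
```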