From edb9e55644ab27994c7d2ba00c7db63d702e6c73 Mon Sep 17 00:00:00 2001 From: tcsenpai Date: Tue, 8 Oct 2024 13:47:07 +0200 Subject: [PATCH] initial transformers text --- data_loader.py | 30 ++++++++++-- main.py | 49 ++++++++++++------- model.py | 70 ++++++-------------------- posture_detector.py | 10 ++-- requirements.txt | 4 +- train.py | 117 +++++++++++++++++++++++++++++++------------- 6 files changed, 164 insertions(+), 116 deletions(-) diff --git a/data_loader.py b/data_loader.py index 579a406..45dd33e 100644 --- a/data_loader.py +++ b/data_loader.py @@ -1,12 +1,30 @@ import os import numpy as np from sklearn.model_selection import train_test_split -from tensorflow.keras.preprocessing.image import load_img, img_to_array +from PIL import Image from colorama import Fore, init # Initialize colorama init(autoreset=True) +def resize_and_crop(img, target_size): + # Resize the image while maintaining aspect ratio + img.thumbnail((target_size[0], target_size[0]), Image.LANCZOS) + + # Get the current size + width, height = img.size + + # Calculate dimensions to crop + left = (width - target_size[0]) // 2 + top = (height - target_size[1]) // 2 + right = left + target_size[0] + bottom = top + target_size[1] + + # Crop the image + img = img.crop((left, top, right, bottom)) + + return img + def load_datasets(dataset_path, img_size=(224, 224)): good_path = os.path.join(dataset_path, 'good') bad_path = os.path.join(dataset_path, 'bad') @@ -18,8 +36,9 @@ def load_datasets(dataset_path, img_size=(224, 224)): good_images = os.listdir(good_path) for img_name in good_images: img_path = os.path.join(good_path, img_name) - img = load_img(img_path, target_size=img_size) - img_array = img_to_array(img) + img = Image.open(img_path).convert('RGB') + img = resize_and_crop(img, img_size) + img_array = np.array(img) data.append(img_array) labels.append(1) @@ -27,8 +46,9 @@ def load_datasets(dataset_path, img_size=(224, 224)): bad_images = os.listdir(bad_path) for img_name in bad_images: img_path = os.path.join(bad_path, img_name) - img = load_img(img_path, target_size=img_size) - img_array = img_to_array(img) + img = Image.open(img_path).convert('RGB') + img = resize_and_crop(img, img_size) + img_array = np.array(img) data.append(img_array) labels.append(0) diff --git a/main.py b/main.py index 70d1e97..bf3ac3a 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,8 @@ import queue import sys import io import subprocess +from transformers import TFViTForImageClassification +import tensorflow as tf class StreamToQueue(io.TextIOBase): def __init__(self, queue): @@ -38,6 +40,7 @@ class PostureDetectionApp: self.dataset_path = os.getenv('DATASET_PATH') self.model_path = os.getenv('MODEL_PATH') + self.setup_ui() self.cameras = self.list_available_cameras() self.populate_camera_list() @@ -67,9 +70,8 @@ class PostureDetectionApp: self.model_entry.insert(0, self.model_path) ttk.Button(left_frame, text="Browse", command=self.browse_model).pack(pady=5) - # ResNet option - self.use_resnet_var = tk.BooleanVar(value=False) - ttk.Checkbutton(left_frame, text="Use ResNet50", variable=self.use_resnet_var).pack(pady=5) + # Model Architecture Label + ttk.Label(left_frame, text="Model Architecture: Vision Transformer").pack(pady=5) # Train button self.train_button = ttk.Button(left_frame, text="Train Model", command=self.train_model) @@ -106,6 +108,10 @@ class PostureDetectionApp: self.dataset_entry.insert(0, path) def browse_model(self): + path = filedialog.asksaveasfilename(defaultextension=".h5", filetypes=[("H5 files", "*.h5")]) + if path: + self.model_entry.delete(0, tk.END) + self.model_entry.insert(0, path) path = filedialog.askopenfilename(filetypes=[("H5 files", "*.h5")]) if path: self.model_entry.delete(0, tk.END) @@ -130,41 +136,37 @@ class PostureDetectionApp: if camera_list: self.camera_combo.current(0) + def toggle_model(self): + if self.use_vit_var.get(): + self.use_vit_var.set(False) + def train_model(self): dataset_path = self.dataset_entry.get() model_path = self.model_entry.get() - use_resnet = self.use_resnet_var.get() - # Disable the train button self.train_button['state'] = 'disabled' - - # Clear the progress text self.progress_text.delete('1.0', tk.END) - # Start the training in a separate thread self.training_thread = threading.Thread(target=self._train_model_thread, - args=(dataset_path, model_path, use_resnet)) + args=(dataset_path, model_path)) self.training_thread.start() - # Start checking the queue for updates self.master.after(100, self._check_training_queue) - def _train_model_thread(self, dataset_path, model_path, use_resnet): - # Redirect stdout to capture print statements + def _train_model_thread(self, dataset_path, model_path): old_stdout = sys.stdout sys.stdout = StreamToQueue(self.training_queue) try: train_data, train_labels, val_data, val_labels = load_datasets(dataset_path) - model = PostureNet(use_resnet=use_resnet) + model = PostureNet() trained_model = train_model(model, train_data, train_labels, val_data, val_labels, None) - trained_model.save(model_path) + trained_model.save_pretrained(model_path) self.training_queue.put(("complete", f"Model saved to {model_path}")) except Exception as e: self.training_queue.put(("error", str(e))) finally: - # Restore stdout sys.stdout = old_stdout def _check_training_queue(self): @@ -205,8 +207,11 @@ class PostureDetectionApp: messagebox.showerror("Error", "No trained model found. Please train the model first.") return - self.trained_model = PostureNet(use_resnet=self.use_resnet_var.get()) - self.trained_model.load_weights(model_path) + try: + self.trained_model = TFViTForImageClassification.from_pretrained(model_path) + except Exception as e: + messagebox.showerror("Error", f"Failed to load the model: {str(e)}") + return camera_index = self.cameras[self.camera_combo.current()] self.cap = cv2.VideoCapture(camera_index) @@ -280,6 +285,16 @@ class PostureDetectionApp: if hasattr(self, 'cap'): self.cap.release() +def detect_posture(frame, model): + preprocessed = model.preprocess_input([frame])['pixel_values'] + outputs = model.model(preprocessed, training=False) + prediction = tf.nn.softmax(outputs.logits) + + if prediction[0][1] > 0.5: + return "Good" + else: + return "Bad" + if __name__ == "__main__": root = tk.Tk() app = PostureDetectionApp(root) diff --git a/model.py b/model.py index 286f3df..3381247 100644 --- a/model.py +++ b/model.py @@ -1,58 +1,20 @@ -from tensorflow.keras.models import Sequential, Model -from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D -from tensorflow.keras.applications import ResNet50 +from transformers import TFViTForImageClassification, ViTConfig, ViTFeatureExtractor -def CustomPostureNet(input_shape=(224, 224, 3)): - model = Sequential([ - Conv2D(32, (3, 3), activation='relu', input_shape=input_shape), - BatchNormalization(), - MaxPooling2D((2, 2)), - Conv2D(64, (3, 3), activation='relu'), - BatchNormalization(), - MaxPooling2D((2, 2)), - Conv2D(128, (3, 3), activation='relu'), - BatchNormalization(), - MaxPooling2D((2, 2)), - Conv2D(128, (3, 3), activation='relu'), - BatchNormalization(), - MaxPooling2D((2, 2)), - Flatten(), - Dense(512, activation='relu'), - BatchNormalization(), - Dropout(0.5), - Dense(256, activation='relu'), - BatchNormalization(), - Dropout(0.5), - Dense(1, activation='sigmoid') - ]) - - model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) - return model +class PostureNet: + def __init__(self): + self.model = None + self.feature_extractor = None -def ResNet50PostureNet(input_shape=(224, 224, 3)): - base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape) - - # Freeze the base model layers - base_model.trainable = False + def build_model(self, num_labels=2): + config = ViTConfig.from_pretrained('google/vit-base-patch16-224', num_labels=num_labels) + self.model = TFViTForImageClassification(config) + self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') + return self.model - # Add custom layers - x = base_model.output - x = GlobalAveragePooling2D()(x) - x = Dense(1024, activation='relu')(x) - x = BatchNormalization()(x) - x = Dropout(0.5)(x) - x = Dense(512, activation='relu')(x) - x = BatchNormalization()(x) - x = Dropout(0.5)(x) - outputs = Dense(1, activation='sigmoid')(x) + def load_model(self, model_path): + self.model = TFViTForImageClassification.from_pretrained(model_path) + self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_path) + return self.model - model = Model(inputs=base_model.input, outputs=outputs) - - model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) - return model - -def PostureNet(use_resnet=False, input_shape=(224, 224, 3)): - if use_resnet: - return ResNet50PostureNet(input_shape) - else: - return CustomPostureNet(input_shape) \ No newline at end of file + def preprocess_input(self, images): + return self.feature_extractor(images=images, return_tensors="tf") diff --git a/posture_detector.py b/posture_detector.py index 53d60ce..a1d7cea 100644 --- a/posture_detector.py +++ b/posture_detector.py @@ -1,17 +1,19 @@ import cv2 import numpy as np +import tensorflow as tf def preprocess_image(frame): - # Resize and normalize the image img = cv2.resize(frame, (224, 224)) img = img.astype('float32') / 255.0 - return np.expand_dims(img, axis=0) + img = np.expand_dims(img, axis=0) + return img def detect_posture(frame, model): preprocessed = preprocess_image(frame) - prediction = model.predict(preprocessed)[0][0] + prediction = model(preprocessed, training=False) + prediction = tf.nn.softmax(prediction.logits) - if prediction > 0.5: + if prediction[0][1] > 0.5: return "Good" else: return "Bad" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8bbcfca..eb0ef51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,6 @@ scikit-learn tensorflow python-dotenv Pillow -colorama \ No newline at end of file +colorama +transformers +torch \ No newline at end of file diff --git a/train.py b/train.py index 4186efe..809c1ce 100644 --- a/train.py +++ b/train.py @@ -1,38 +1,85 @@ -from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau -from tensorflow.keras.optimizers import Adam +import tensorflow as tf +from transformers import create_optimizer +from tensorflow.keras.callbacks import EarlyStopping -def train_model(model, train_data, train_labels, val_data, val_labels, datagen): - early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True) - reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6) - - # Initial training phase - history = model.fit( - train_data, train_labels, - epochs=50, - batch_size=32, - validation_data=(val_data, val_labels), - callbacks=[early_stopping, reduce_lr] - ) - - # Check if it's a ResNet model - if 'resnet' in model.name.lower(): - print("Fine-tuning ResNet model...") - # Fine-tuning phase for ResNet model - base_model = model.layers[0] - base_model.trainable = True - - # Freeze first 100 layers - for layer in base_model.layers[:100]: - layer.trainable = False - - model.compile(optimizer=Adam(1e-5), loss='binary_crossentropy', metrics=['accuracy']) - - history_fine = model.fit( - train_data, train_labels, - epochs=50, - batch_size=32, - validation_data=(val_data, val_labels), - callbacks=[early_stopping, reduce_lr] +def train_model(model, train_data, train_labels, val_data, val_labels, callbacks): + try: + built_model = model.build_model() + + # Build the model by calling it on a batch of data + dummy_input = tf.zeros((1, 224, 224, 3), dtype=tf.float32) + _ = built_model(dummy_input, training=False) + + # Print model summary + built_model.summary() + + # Create optimizer + num_train_steps = len(train_data) // 32 * 50 # assuming batch_size=32 and epochs=50 + optimizer, lr_schedule = create_optimizer( + init_lr=2e-5, + num_train_steps=num_train_steps, + num_warmup_steps=0, + weight_decay_rate=0.01, ) - return model \ No newline at end of file + if callbacks is None: + callbacks = [EarlyStopping(patience=5, restore_best_weights=True)] + + # Preprocess data using the feature extractor + train_data = model.preprocess_input(train_data)['pixel_values'] + val_data = model.preprocess_input(val_data)['pixel_values'] + + # Convert labels to TensorFlow tensors + train_labels = tf.convert_to_tensor(train_labels, dtype=tf.int32) + val_labels = tf.convert_to_tensor(val_labels, dtype=tf.int32) + + # Print shapes after conversion + print(f"Train data shape: {train_data.shape}") + print(f"Train labels shape: {train_labels.shape}") + print(f"Validation data shape: {val_data.shape}") + print(f"Validation labels shape: {val_labels.shape}") + + # Custom training loop + train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels)).batch(32) + val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels)).batch(32) + + for epoch in range(50): # 50 epochs + print(f"Epoch {epoch + 1}/{50}") + for step, (batch_images, batch_labels) in enumerate(train_dataset): + with tf.GradientTape() as tape: + outputs = built_model(batch_images, training=True) + loss = tf.keras.losses.sparse_categorical_crossentropy( + batch_labels, outputs.logits, from_logits=True + ) + loss = tf.reduce_mean(loss) + + grads = tape.gradient(loss, built_model.trainable_variables) + optimizer.apply_gradients(zip(grads, built_model.trainable_variables)) + + if step % 50 == 0: + print(f"Step {step}, Loss: {loss:.4f}") + + # Validation + val_loss = 0 + val_accuracy = 0 + for val_images, val_labels in val_dataset: + val_outputs = built_model(val_images, training=False) + val_loss += tf.reduce_mean( + tf.keras.losses.sparse_categorical_crossentropy( + val_labels, val_outputs.logits, from_logits=True + ) + ) + val_accuracy += tf.reduce_mean( + tf.keras.metrics.sparse_categorical_accuracy(val_labels, val_outputs.logits) + ) + + val_loss /= len(val_dataset) + val_accuracy /= len(val_dataset) + print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}") + + return built_model + + except Exception as e: + print(f"An error occurred during model training: {str(e)}") + print(f"Error type: {type(e).__name__}") + raise \ No newline at end of file