poser/train.py
2024-10-08 13:47:07 +02:00

85 lines
3.5 KiB
Python

import tensorflow as tf
from transformers import create_optimizer
from tensorflow.keras.callbacks import EarlyStopping
def train_model(model, train_data, train_labels, val_data, val_labels, callbacks,
                batch_size=32, epochs=50):
    """Fine-tune a vision model with a custom TensorFlow training loop.

    Builds the model via ``model.build_model()``, preprocesses the raw inputs
    through ``model.preprocess_input(...)['pixel_values']``, then trains with
    a manual ``GradientTape`` loop using transformers' AdamW optimizer with a
    linear-decay learning-rate schedule.

    Args:
        model: Project wrapper exposing ``build_model()`` and
            ``preprocess_input()`` (assumed HuggingFace-style; the built
            model's forward pass returns an object with ``.logits`` —
            TODO confirm against the model class).
        train_data / val_data: Raw images accepted by ``preprocess_input``;
            the built model expects 224x224x3 float inputs (see dummy pass).
        train_labels / val_labels: Integer class labels.
        callbacks: Optional list of Keras callbacks. NOTE(review): callbacks
            are currently NOT invoked — Keras callbacks only fire through
            ``model.fit()``, and this function uses a custom loop. The
            EarlyStopping default is therefore inert; kept for interface
            compatibility. TODO: wire callbacks in or switch to ``fit()``.
        batch_size: Mini-batch size (default 32, matching prior behavior).
        epochs: Number of training epochs (default 50, matching prior behavior).

    Returns:
        The trained model.

    Raises:
        Re-raises any exception from building, preprocessing, or training
        after logging its message and type.
    """
    try:
        built_model = model.build_model()
        # Materialize the weights by running one dummy forward pass.
        dummy_input = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
        _ = built_model(dummy_input, training=False)
        built_model.summary()

        # AdamW with linear decay over the full training run.
        num_train_steps = (len(train_data) // batch_size) * epochs
        optimizer, lr_schedule = create_optimizer(
            init_lr=2e-5,
            num_train_steps=num_train_steps,
            num_warmup_steps=0,
            weight_decay_rate=0.01,
        )

        if callbacks is None:
            # Inert with the custom loop below — see the docstring note.
            callbacks = [EarlyStopping(patience=5, restore_best_weights=True)]

        # Feature-extractor preprocessing; 'pixel_values' is the model input.
        train_data = model.preprocess_input(train_data)['pixel_values']
        val_data = model.preprocess_input(val_data)['pixel_values']
        train_labels = tf.convert_to_tensor(train_labels, dtype=tf.int32)
        val_labels = tf.convert_to_tensor(val_labels, dtype=tf.int32)

        print(f"Train data shape: {train_data.shape}")
        print(f"Train labels shape: {train_labels.shape}")
        print(f"Validation data shape: {val_data.shape}")
        print(f"Validation labels shape: {val_labels.shape}")

        train_dataset = tf.data.Dataset.from_tensor_slices(
            (train_data, train_labels)).batch(batch_size)
        val_dataset = tf.data.Dataset.from_tensor_slices(
            (val_data, val_labels)).batch(batch_size)

        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")
            for step, (batch_images, batch_labels) in enumerate(train_dataset):
                with tf.GradientTape() as tape:
                    outputs = built_model(batch_images, training=True)
                    loss = tf.keras.losses.sparse_categorical_crossentropy(
                        batch_labels, outputs.logits, from_logits=True
                    )
                    loss = tf.reduce_mean(loss)
                grads = tape.gradient(loss, built_model.trainable_variables)
                optimizer.apply_gradients(zip(grads, built_model.trainable_variables))
                if step % 50 == 0:
                    # float() avoids relying on EagerTensor.__format__.
                    print(f"Step {step}, Loss: {float(loss):.4f}")

            # Validation pass. Loop variables are renamed so they do not
            # shadow (and clobber) the function-level `val_labels` tensor.
            val_loss = 0.0
            val_accuracy = 0.0
            for batch_val_images, batch_val_labels in val_dataset:
                val_outputs = built_model(batch_val_images, training=False)
                val_loss += tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        batch_val_labels, val_outputs.logits, from_logits=True
                    )
                )
                val_accuracy += tf.reduce_mean(
                    tf.keras.metrics.sparse_categorical_accuracy(
                        batch_val_labels, val_outputs.logits
                    )
                )
            val_loss /= len(val_dataset)
            val_accuracy /= len(val_dataset)
            print(f"Validation Loss: {float(val_loss):.4f}, "
                  f"Validation Accuracy: {float(val_accuracy):.4f}")

        return built_model
    except Exception as e:
        print(f"An error occurred during model training: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        raise