mirror of
https://github.com/tcsenpai/poser.git
synced 2025-06-06 11:15:22 +00:00
85 lines
3.5 KiB
Python
85 lines
3.5 KiB
Python
import tensorflow as tf
|
|
from transformers import create_optimizer
|
|
from tensorflow.keras.callbacks import EarlyStopping
|
|
|
|
def train_model(model, train_data, train_labels, val_data, val_labels, callbacks,
                batch_size=32, epochs=50, patience=5):
    """Fine-tune an image-classification model with a custom TF training loop.

    Parameters
    ----------
    model : object
        Project model wrapper exposing ``build_model()`` (returns a callable
        Keras/transformers model whose outputs carry ``.logits``) and
        ``preprocess_input()`` (returns a dict with a ``'pixel_values'`` key).
    train_data, val_data : array-like
        Raw images; preprocessed via ``model.preprocess_input``.
    train_labels, val_labels : array-like of int
        Sparse integer class labels.
    callbacks : list or None
        Kept for interface compatibility. NOTE(review): the original code
        built an EarlyStopping callback here but the custom loop never
        invoked it; this version implements equivalent early stopping
        directly in the loop. Arbitrary Keras callbacks passed in are still
        NOT executed.
    batch_size : int, optional
        Mini-batch size (previously hard-coded to 32 in several places).
    epochs : int, optional
        Maximum number of epochs (previously hard-coded to 50).
    patience : int, optional
        Epochs without validation-loss improvement before stopping early
        (matches the original EarlyStopping(patience=5) intent).

    Returns
    -------
    The built, trained model with the best-validation-loss weights restored.

    Raises
    ------
    Re-raises any exception that occurs during training, after logging it.
    """
    try:
        built_model = model.build_model()

        # Build the model by calling it on a dummy batch so variables exist
        # before summary() and gradient computation.
        # NOTE(review): assumes 224x224 RGB input — confirm against the
        # feature extractor's expected resolution.
        dummy_input = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
        _ = built_model(dummy_input, training=False)

        built_model.summary()

        # Optimizer with linear decay schedule from transformers; the
        # schedule object itself is not needed separately here.
        num_train_steps = len(train_data) // batch_size * epochs
        optimizer, _lr_schedule = create_optimizer(
            init_lr=2e-5,
            num_train_steps=num_train_steps,
            num_warmup_steps=0,
            weight_decay_rate=0.01,
        )

        # Preprocess data using the model's feature extractor.
        train_data = model.preprocess_input(train_data)['pixel_values']
        val_data = model.preprocess_input(val_data)['pixel_values']

        # Convert labels to TensorFlow tensors.
        train_labels = tf.convert_to_tensor(train_labels, dtype=tf.int32)
        val_labels = tf.convert_to_tensor(val_labels, dtype=tf.int32)

        # Print shapes after conversion.
        print(f"Train data shape: {train_data.shape}")
        print(f"Train labels shape: {train_labels.shape}")
        print(f"Validation data shape: {val_data.shape}")
        print(f"Validation labels shape: {val_labels.shape}")

        train_dataset = tf.data.Dataset.from_tensor_slices(
            (train_data, train_labels)).batch(batch_size)
        val_dataset = tf.data.Dataset.from_tensor_slices(
            (val_data, val_labels)).batch(batch_size)

        # Manual early-stopping state. This replaces the EarlyStopping
        # callback the original constructed but never wired into the loop.
        best_val_loss = float('inf')
        best_weights = None
        epochs_without_improvement = 0

        # Custom training loop.
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")
            for step, (batch_images, batch_labels) in enumerate(train_dataset):
                with tf.GradientTape() as tape:
                    outputs = built_model(batch_images, training=True)
                    loss = tf.keras.losses.sparse_categorical_crossentropy(
                        batch_labels, outputs.logits, from_logits=True
                    )
                    loss = tf.reduce_mean(loss)

                grads = tape.gradient(loss, built_model.trainable_variables)
                optimizer.apply_gradients(zip(grads, built_model.trainable_variables))

                if step % 50 == 0:
                    print(f"Step {step}, Loss: {loss:.4f}")

            # Validation pass. Distinct loop-variable names so the outer
            # val_labels tensor is not shadowed (a bug in the original).
            val_loss = 0
            val_accuracy = 0
            for batch_val_images, batch_val_labels in val_dataset:
                val_outputs = built_model(batch_val_images, training=False)
                val_loss += tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        batch_val_labels, val_outputs.logits, from_logits=True
                    )
                )
                val_accuracy += tf.reduce_mean(
                    tf.keras.metrics.sparse_categorical_accuracy(
                        batch_val_labels, val_outputs.logits
                    )
                )

            val_loss /= len(val_dataset)
            val_accuracy /= len(val_dataset)
            print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

            # Early stopping with best-weight tracking (restore_best_weights
            # semantics from the original, unused, EarlyStopping callback).
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_weights = built_model.get_weights()
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    print(f"Early stopping after {epoch + 1} epochs.")
                    break

        # Restore the best weights observed on the validation set.
        if best_weights is not None:
            built_model.set_weights(best_weights)

        return built_model

    except Exception as e:
        print(f"An error occurred during model training: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        raise