mirror of
https://github.com/tcsenpai/poser.git
synced 2025-06-03 01:40:07 +00:00
58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
from tensorflow.keras.models import Sequential, Model
|
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D
|
|
from tensorflow.keras.applications import ResNet50
|
|
|
|
def CustomPostureNet(input_shape=(224, 224, 3)):
    """Build and compile a small from-scratch CNN for binary posture classification.

    Architecture:
      * four convolutional stages, each Conv2D(3x3, ReLU) -> BatchNormalization
        -> MaxPooling2D(2x2), with 32, 64, 128 and 128 filters;
      * Flatten;
      * two fully connected stages, each Dense(ReLU) -> BatchNormalization
        -> Dropout(0.5), with 512 and 256 units;
      * a single-unit sigmoid output.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
            Defaults to (224, 224, 3).

    Returns:
        A ``Sequential`` model compiled with the ``'adam'`` optimizer,
        ``'binary_crossentropy'`` loss and an ``'accuracy'`` metric.
    """
    # Declare the input with an explicit Input element instead of passing
    # `input_shape=` to the first Conv2D: Keras 3 deprecates that layer kwarg
    # (it warns), while an Input first element is accepted by both tf.keras 2.x
    # and Keras 3 and yields the same layers and weights.
    from tensorflow.keras.layers import Input

    model = Sequential([
        Input(shape=input_shape),

        # Convolutional feature extractor.
        Conv2D(32, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),

        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),

        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),

        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),

        # Classifier head.
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),

        Dense(256, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),

        # One sigmoid unit: probability of the positive class.
        Dense(1, activation='sigmoid'),
    ])

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model
|
|
|
|
def ResNet50PostureNet(input_shape=(224, 224, 3)):
    """Build and compile a binary posture classifier on a frozen ResNet50 backbone.

    The backbone is Keras' ``ResNet50`` loaded with ImageNet weights and
    without its classification top. All of its layers are frozen. On top of
    its output sit a GlobalAveragePooling2D, two Dense(ReLU) ->
    BatchNormalization -> Dropout(0.5) stages (1024 and 512 units) and a
    single sigmoid unit.

    NOTE(review): the ImageNet ResNet50 weights are normally paired with
    ``tensorflow.keras.applications.resnet50.preprocess_input``. This function
    does not apply it, so confirm that callers preprocess inputs accordingly.

    Args:
        input_shape: Shape of one input image, excluding the batch axis,
            forwarded to ``ResNet50``. Defaults to (224, 224, 3).

    Returns:
        A functional ``Model`` compiled with the ``'adam'`` optimizer,
        ``'binary_crossentropy'`` loss and an ``'accuracy'`` metric.
    """
    backbone = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)

    # Freeze every pretrained layer so that only the new head is trained.
    backbone.trainable = False

    # Collapse the spatial feature map to one vector per image, then apply the
    # two identical fully connected stages in order (1024 units, then 512).
    features = GlobalAveragePooling2D()(backbone.output)
    for units in (1024, 512):
        features = Dense(units, activation='relu')(features)
        features = BatchNormalization()(features)
        features = Dropout(0.5)(features)

    prediction = Dense(1, activation='sigmoid')(features)

    model = Model(inputs=backbone.input, outputs=prediction)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
|
|
|
|
def PostureNet(use_resnet=False, input_shape=(224, 224, 3)):
    """Return a compiled posture classifier, selecting the architecture by flag.

    Args:
        use_resnet: If truthy, build the frozen-ResNet50 transfer model
            (``ResNet50PostureNet``). Otherwise build the from-scratch CNN
            (``CustomPostureNet``).
        input_shape: Shape of one input image, excluding the batch axis,
            forwarded unchanged to the chosen builder.

    Returns:
        Whatever the selected builder returns (a compiled Keras model).
    """
    builder = ResNet50PostureNet if use_resnet else CustomPostureNet
    return builder(input_shape)