-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
51 lines (40 loc) · 1.79 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from keras.losses import SparseCategoricalCrossentropy
from keras.optimizers import Adam
from keras.metrics import SparseCategoricalAccuracy
from ade_dataset import *
from keras.callbacks import ModelCheckpoint
from models.structrual.u_net import UNet
import matplotlib.pyplot as plt
def _plot_history(history):
    """Plot per-epoch training/validation accuracy and loss side by side.

    Args:
        history: The ``History`` object returned by ``model.fit``. Its
            ``history`` dict is read for the keys
            ``'sparse_categorical_accuracy'``,
            ``'val_sparse_categorical_accuracy'``, ``'loss'`` and
            ``'val_loss'``. These exist because ``main`` compiles the model
            with ``SparseCategoricalAccuracy`` and fits with validation data.
    """
    curves = history.history
    train_acc, train_loss = curves['sparse_categorical_accuracy'], curves['loss']
    val_acc, val_loss = curves['val_sparse_categorical_accuracy'], curves['val_loss']

    plt.figure(figsize=(12, 5))

    # Left panel: accuracy curves.
    plt.subplot(1, 2, 1)
    plt.plot(train_acc, color='purple')
    plt.plot(val_acc, color='red')
    plt.xlabel('$epochs$')
    plt.ylabel('$accuracy$')
    plt.legend(['train_accuracy', 'val_accuracy'])

    # Right panel: loss curves.
    plt.subplot(1, 2, 2)
    plt.plot(train_loss, color='deeppink')
    plt.plot(val_loss, color='deepskyblue')
    plt.xlabel('$epochs$')
    plt.ylabel('$loss$')
    plt.legend(['train_loss', 'val_loss'])

    plt.show()


def main(*, num_classes=150, learning_rate=0.001, epochs=20):
    """Train a U-Net segmentation model, report validation metrics, plot curves.

    The best weights by validation accuracy are saved under ``./checkpoint/``.

    NOTE(review): ``train_gen`` and ``val_gen`` are not defined in the visible
    code. They presumably come from ``from ade_dataset import *``, and
    ``train_gen.input_shape`` is assumed to be the per-sample input shape —
    TODO confirm against ade_dataset.

    Keyword Args:
        num_classes: Number of output classes passed to ``UNet``.
            Defaults to 150, the value this script originally hard-coded.
        learning_rate: Learning rate for the Adam optimizer.
        epochs: Number of training epochs.
    """
    from pathlib import Path

    model = UNet(train_gen.input_shape, classes=num_classes)
    # model = SegFormerB0(input_shape, num_classes=150, attention_drop_rate=0.2, drop_rate=0.1)

    # NOTE(review): SparseCategoricalCrossentropy() defaults to
    # from_logits=False, which assumes UNet ends in a softmax — confirm.
    model.compile(
        loss=SparseCategoricalCrossentropy(),
        optimizer=Adam(learning_rate=learning_rate),
        # `compile` expects a list/dict of metrics; Keras 3 rejects a bare object.
        metrics=[SparseCategoricalAccuracy()],
    )

    # Ensure the checkpoint directory exists before the first save.
    checkpoint_dir = Path('./checkpoint')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): Keras 3 requires save_weights_only paths to end in
    # '.weights.h5'; this name only works on Keras 2 — confirm target version.
    checkpoint = ModelCheckpoint(filepath=f'./checkpoint/{model.name}_ade_weights.h5',
                                 monitor='val_sparse_categorical_accuracy',
                                 save_best_only=True, save_weights_only=True, mode='auto')

    # summary() prints itself and returns None; wrapping it in print() emitted "None".
    model.summary()

    # No batch_size here: with generator/Sequence input the batch size comes
    # from train_gen itself, and Keras ignores or rejects an explicit value.
    history = model.fit(train_gen, epochs=epochs, validation_data=val_gen,
                        callbacks=[checkpoint])

    # Evaluates the final-epoch weights, not the best checkpointed ones.
    eval_loss, eval_acc = model.evaluate(val_gen)
    print(f"model validation loss: {eval_loss}, validation accuracy: {eval_acc}")

    _plot_history(history)


if __name__ == '__main__':
    main()