This repository was archived by the owner on Nov 17, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_flow_network.py
executable file
·65 lines (55 loc) · 2.16 KB
/
train_flow_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Generic imports
import time
import numpy as np
# Custom imports
from params import *
from datasets_utils import *
from networks_utils import *
import networks_utils
# Train a flow-prediction network end to end: load the input images and their
# solution images, split both into training/validation/test sets, build and
# train the network whose builder function is named by `network` (looked up in
# networks_utils), then evaluate on the test set, save the model and plot the
# training history.
#
# Configuration names (input_dir, sol_dir, downscaling, color, train_size,
# valid_size, tests_size, network, n_filters_initial, kernel_size, ...) are
# expected to come from `from params import *`.

# Load input images. Their spatial size is kept under separate names so it can
# be checked against the solutions below.
imgs, n_imgs, img_height, img_width, _ = load_img_dataset(input_dir,
                                                          downscaling,
                                                          color)
# Load solutions. height/width/n_channels from this call are the values handed
# to the network builder (same as before).
sols, n_sols, height, width, n_channels = load_img_dataset(sol_dir,
                                                           downscaling,
                                                           color)

# Inputs and solutions go through two separate split_dataset calls, so
# (presumably position-based) pairs only stay aligned when both datasets have
# the same number of samples. Fail fast instead of training on mismatched pairs.
if imgs.shape[0] != sols.shape[0]:
    raise ValueError(
        f'Found {imgs.shape[0]} input images but {sols.shape[0]} solution '
        'images; every input image needs exactly one solution image')
# The network is built for a single height/width, so both datasets must agree.
if (img_height, img_width) != (height, width):
    raise ValueError(
        f'Input images are {img_width}x{img_height} but solution images are '
        f'{width}x{height}; both must have the same size')

# Split data into training, validation and testing sets
(imgs_train,
 imgs_valid,
 imgs_tests) = split_dataset(imgs, train_size, valid_size, tests_size)
(sols_train,
 sols_valid,
 sols_tests) = split_dataset(sols, train_size, valid_size, tests_size)

# Print dataset information
print('Training set size is', imgs_train.shape[0])
print('Validation set size is', imgs_valid.shape[0])
print('Test set size is', imgs_tests.shape[0])
print('Input images downscaled to', str(width) + 'x' + str(height))

# Set the network and train it. The builder is selected by name at runtime.
# perf_counter is monotonic, so the measured duration is unaffected by
# system clock adjustments (unlike time.time()).
start = time.perf_counter()
regression = getattr(networks_utils, network)
model, train_model = regression(imgs_train,
                                sols_train,
                                imgs_valid,
                                sols_valid,
                                imgs_tests,
                                n_filters_initial,
                                kernel_size,
                                kernel_transpose_size,
                                pool_size,
                                stride_size,
                                learning_rate,
                                batch_size,
                                n_epochs,
                                height,
                                width,
                                n_channels)
elapsed = time.perf_counter() - start
print('Training time was', elapsed, 'seconds')

# Evaluate score on test set. NOTE(review): `score` is not used further in this
# script; confirm whether evaluate_model_score reports it itself.
score = evaluate_model_score(model, imgs_tests, sols_tests)
# Save model
save_keras_model(model)
# Plot accuracy and loss from the training history returned by the builder
plot_accuracy_and_loss(train_model)