5
5
import mnist
6
6
7
7
from neural_network import NeuralNetwork
8
- from preprocessing import *
9
8
from training import stochastic_gradient_descent
10
-
11
- NUM_EXAMPLES = 59999
9
+ from preprocessing import *
12
10
13
11
14
- def test_mnist_one_hot (num_train_examples = - 1 , num_test_examples = - 1 , hidden_layers = (24 , 32 ), sigmoid = 'tanh' ,
15
- learning_rate = 0.01 , learning_decay = 1.0 , momentum = 0.0 , batch_size = 100 , num_epochs = 100 ,
12
+ def test_mnist_one_hot (num_train_examples = - 1 , num_test_examples = - 1 , hidden_layers = (100 , ), sigmoid = 'tanh' ,
13
+ learning_rate = 0.01 , layer_decay = 1.0 , momentum = 0.0 , batch_size = 100 , num_epochs = 100 ,
16
14
csv_filename = None , return_test_accuracies = False ):
17
- layer_sizes = (784 ,) + hidden_layers + (10 ,)
18
- weight_decay = 0.0
19
-
20
- print ('Network Parameters' )
21
- print ('layer_sizes: {}, sigmoid: {}, weight_decay: {}' .format (layer_sizes , sigmoid , weight_decay ))
22
-
23
- # Set the training parameters.
24
- num_iterations = (NUM_EXAMPLES // batch_size ) * num_epochs
25
-
26
- print ('Training Parameters' )
27
- print ('num_iterations: {}, learning_rate: {}, learning_decay: {}, momentum: {}, batch_size: {}' .format (
28
- num_iterations , learning_rate , learning_decay , momentum , batch_size ))
29
-
30
- print ('' )
31
-
32
15
# Collect and preprocess the data.
33
16
if sigmoid == 'logistic' :
34
17
train_input = convert_mnist_images_logistic (mnist .train_images ()[:num_train_examples ])
@@ -46,9 +29,12 @@ def test_mnist_one_hot(num_train_examples=-1, num_test_examples=-1, hidden_layer
46
29
raise ValueError ('Invalid sigmoid function.' )
47
30
48
31
# Create and train the neural network.
32
+ layer_sizes = (784 ,) + hidden_layers + (10 ,)
33
+ weight_decay = 0.0
49
34
nn = NeuralNetwork (layer_sizes , sigmoid = sigmoid , weight_decay = weight_decay )
50
35
51
36
num_examples = train_input .shape [0 ]
37
+ num_iterations = (num_examples // batch_size ) * num_epochs
52
38
53
39
rows = None
54
40
if csv_filename is not None :
@@ -61,23 +47,31 @@ def test_mnist_one_hot(num_train_examples=-1, num_test_examples=-1, hidden_layer
61
47
def callback (iteration ):
62
48
if iteration % (num_examples // batch_size ) == 0 :
63
49
epoch = iteration // (num_examples // batch_size )
64
- training_prediction_rate = get_prediction_rate (nn , train_input , train_output )
65
- test_prediction_rate = get_prediction_rate (nn , test_input , test_output )
50
+ training_prediction_accuracy = get_prediction_accuracy (nn , train_input , train_output )
51
+ test_prediction_accuracy = get_prediction_accuracy (nn , test_input , test_output )
66
52
training_loss = nn .get_loss (train_input , train_output )
67
53
test_loss = nn .get_loss (test_input , test_output )
68
- print ('{},{:.6f},{:.6f},{:.6f},{:.6f}' .format (epoch , training_prediction_rate , test_prediction_rate ,
54
+ print ('{},{:.6f},{:.6f},{:.6f},{:.6f}' .format (epoch , training_prediction_accuracy , test_prediction_accuracy ,
69
55
training_loss , test_loss ))
70
56
if csv_filename is not None :
71
- rows .append ((epoch , training_prediction_rate , test_prediction_rate , training_loss , test_loss ))
57
+ rows .append ((epoch , training_prediction_accuracy , test_prediction_accuracy , training_loss , test_loss ))
72
58
if return_test_accuracies :
73
- test_accuracies .append (test_prediction_rate )
59
+ test_accuracies .append (test_prediction_accuracy )
60
+
61
+ print ('Network Parameters' )
62
+ print ('layer_sizes: {}, sigmoid: {}, weight_decay: {}' .format (layer_sizes , sigmoid , weight_decay ))
63
+ print ('Training Parameters' )
64
+ print ('num_iterations: {}, learning_rate: {}, layer_decay: {}, momentum: {}, batch_size: {}' .format (
65
+ num_iterations , learning_rate , layer_decay , momentum , batch_size ))
66
+ print ('' )
74
67
75
68
header = 'epoch,training_accuracy,test_accuracy,training_loss,test_loss'
76
69
print (header )
77
70
stochastic_gradient_descent (nn , train_input , train_output , num_iterations = num_iterations ,
78
- learning_rate = learning_rate , learning_decay = learning_decay ,
71
+ learning_rate = learning_rate , layer_decay = layer_decay ,
79
72
momentum = momentum , batch_size = batch_size ,
80
73
callback = callback )
74
+
81
75
if csv_filename is not None :
82
76
save_rows_to_csv (csv_filename , rows , header .split (',' ))
83
77
0 commit comments