-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtraining_plot.py
executable file
·121 lines (92 loc) · 3.34 KB
/
training_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/python
import argparse
import json
import re
import sys
import matplotlib.pyplot as plot
import numpy as np
""" Utility for plotting data collected during training runs. """
def plot_training(loss, testing_acc, iter_step):
    """ Plots the training results.

    Args:
        loss: The list of loss data.
        testing_acc: The list of testing accuracy data.
        iter_step: How many iterations elapsed between each logging
            interval. """
    # Make two subplots that share the same time axis.
    fig, time_axes = plot.subplots(2, sharex=True)

    # Accuracy is logged less often than loss; stretch it so both series have
    # (roughly) the same length. Integer division is required here: on
    # Python 3, "/" yields a float and np.repeat rejects float repeat counts.
    testing_acc = np.repeat(testing_acc, len(loss) // len(testing_acc))
    # Cut the last bit so they're the same shape.
    loss = loss[:len(testing_acc)]

    # Compute x values in units of training iterations.
    x_values = range(0, len(loss))
    x_values = np.multiply(x_values, iter_step)

    # Plot everything.
    time_axes[0].plot(x_values, loss)
    time_axes[1].plot(x_values, testing_acc)

    # One x label at the bottom.
    time_axes[1].set_xlabel("Iterations")
    # Y labels for each graph.
    time_axes[0].set_ylabel("Loss")
    time_axes[1].set_ylabel("Testing Accuracy")

    fig.tight_layout()
    plot.show()
def average_filter(data, window):
    """ Uses a sliding-window average filter to reduce data noise.

    Args:
        data: The sequence of samples to smooth.
        window: How many consecutive samples are averaged per output point.

    Returns:
        The smoothed data, one sample per full window, i.e. of length
        len(data) - window + 1. """
    if window == 1:
        # No averaging.
        return data

    averaged = []
    # NOTE: the upper bound is len(data) - window + 1; the previous version
    # stopped one short and silently dropped the final full window.
    for start in range(len(data) - window + 1):
        averaged.append(np.mean(data[start:start + window]))
    return averaged
def load_log(log_file):
    """ Parses data from a log file instead of the JSON dump.

    Args:
        log_file: The file to load data from.

    Returns:
        List of the testing_loss, training loss, testing accuracy, and
        training accuracy. """
    testing_loss = []
    training_loss = []
    testing_acc = []
    training_acc = []

    # Raw string so "\d" is a regex escape, not an (invalid) string escape,
    # and "\d+" before the decimal point so multi-digit values like "12.5"
    # are not truncated to "2.5". Compiled once, outside the loop.
    number_pattern = re.compile(r"\d+\.\d+")

    for line in log_file.read().split("\n"):
        if "Training loss" in line:
            # This line contains the training loss and accuracy.
            loss, acc = [float(num) for num in number_pattern.findall(line)]
            training_loss.append(loss)
            training_acc.append(acc)
        if "Testing loss" in line:
            # This line contains the testing loss and accuracy.
            loss, acc = [float(num) for num in number_pattern.findall(line)]
            testing_loss.append(loss)
            testing_acc.append(acc)

    return testing_loss, training_loss, testing_acc, training_acc
def main():
    """ Parses command-line arguments, loads the logged training data,
    optionally smooths it, and plots it. """
    parser = argparse.ArgumentParser("Analyze training data logs.")
    parser.add_argument("data_file", help="The data file to analyze.")
    parser.add_argument("-l", "--log_file", action="store_true",
                        help="Analyze log file instead of JSON dump.")
    parser.add_argument("-i", "--interval", default=1, type=int,
                        help="Number of iterations between training logs.")
    parser.add_argument("-f", "--filter_interval", default=1, type=int,
                        help="Window size for average filtering.")
    args = parser.parse_args()

    # Load the logged data. The Python 2 builtin file() does not exist on
    # Python 3; open() works everywhere, and the with-block guarantees the
    # handle is closed even if parsing raises.
    with open(args.data_file) as log_file:
        if not args.log_file:
            loss, test_acc, _ = json.load(log_file)
        else:
            # Parse from the log file.
            _, loss, test_acc, _ = load_log(log_file)

    # Average filtering.
    loss = average_filter(loss, args.filter_interval)
    test_acc = average_filter(test_acc, args.filter_interval)

    plot_training(loss, test_acc, args.interval)
# Only run when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()