# -*- coding: utf-8 -*-
"""
Created on Thu Oct 4 12:06:24 2018

@author: Steve O'Hagan

References
    Blog post: jeremyjordan.me/nn-learning-rate
    Original paper: https://arxiv.org/abs/1506.01186
"""
from keras import backend as K
from keras.callbacks import Callback
import numpy as np
import matplotlib.pyplot as plt
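
# For reference, the triangular CLR policy from the paper above
# (https://arxiv.org/abs/1506.01186), as computed by CLRScheduler.clr() below:
# with stepsize s and iteration t,
#     cycle = floor(1 + t / (2*s))
#     x     = |t / s - 2*cycle + 1|
#     lr(t) = min_lr + (max_lr - min_lr) * max(0, 1 - x)
# i.e. the learning rate climbs linearly from min_lr to max_lr over s
# iterations, then descends back to min_lr over the next s iterations.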


class CLRScheduler(Callback):
    '''Cyclical learning rate scheduler implementing the triangular policy
    from https://arxiv.org/abs/1506.01186.'''

    def __init__(self, min_lr=1e-5, max_lr=1e-2, steps_per_epoch=None, epochs=None):
        super().__init__()
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.total_iterations = steps_per_epoch * epochs
        self.iteration = 0
        self.history = {}
        # One full cycle spans two epochs: lr rises for one epoch, falls for one.
        self.stepsize = steps_per_epoch * 2

    def clr(self):
        '''Calculate the learning rate for the current iteration.'''
        cycle = np.floor(1 + self.iteration / (2 * self.stepsize))
        x = np.abs(self.iteration / self.stepsize - 2 * cycle + 1)
        lr = self.min_lr + (self.max_lr - self.min_lr) * np.maximum(0, (1 - x))
        return lr

    def on_train_begin(self, logs=None):
        '''Initialize the learning rate to the minimum value at the start of training.'''
        logs = logs or {}
        K.set_value(self.model.optimizer.lr, self.min_lr)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.iteration += 1
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        self.history.setdefault('iterations', []).append(self.iteration)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)
        K.set_value(self.model.optimizer.lr, self.clr())

    def plot_lr(self):
        '''Helper function to quickly inspect the learning rate schedule.'''
        plt.plot(self.history['iterations'], self.history['lr'])
        #plt.yscale('log')
        plt.xlabel('Iteration')
        plt.ylabel('Learning rate')
        plt.show()

    def plot_loss(self):
        '''Helper function to quickly observe the learning rate experiment results.'''
        plt.plot(self.history['lr'], self.history['loss'])
        plt.xscale('log')
        plt.xlabel('Learning rate')
        plt.ylabel('Loss')
        plt.show()
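
# A usage sketch for CLRScheduler (illustrative only; assumes the caller has a
# compiled `model` plus `X_train`, `Y_train`, `epoch_size`, and `batch_size`):
#
#     clr = CLRScheduler(min_lr=1e-5, max_lr=1e-2,
#                        steps_per_epoch=np.ceil(epoch_size / batch_size),
#                        epochs=10)
#     model.fit(X_train, Y_train, callbacks=[clr])
#     clr.plot_lr()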


class LRFinder(Callback):
    '''
    A simple callback for finding the optimal learning rate range for your model + dataset.

    # Usage
        ```python
        lr_finder = LRFinder(min_lr=1e-5,
                             max_lr=1e-2,
                             steps_per_epoch=np.ceil(epoch_size/batch_size),
                             epochs=3)
        model.fit(X_train, Y_train, callbacks=[lr_finder])
        lr_finder.plot_loss()
        ```

    # Arguments
        min_lr: The lower bound of the learning rate range for the experiment.
        max_lr: The upper bound of the learning rate range for the experiment.
        steps_per_epoch: Number of mini-batches in the dataset, i.e. `np.ceil(epoch_size/batch_size)`.
        epochs: Number of epochs over which to run the experiment; 2 to 4 is usually sufficient.

    # References
        Blog post: jeremyjordan.me/nn-learning-rate
        Original paper: https://arxiv.org/abs/1506.01186
    '''

    def __init__(self, min_lr=1e-5, max_lr=1e-2, steps_per_epoch=None, epochs=None):
        super().__init__()
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.total_iterations = steps_per_epoch * epochs
        self.iteration = 0
        self.history = {}

    def clr(self):
        '''Calculate the learning rate, which increases linearly with each iteration.'''
        x = self.iteration / self.total_iterations
        return self.min_lr + (self.max_lr - self.min_lr) * x

    def on_train_begin(self, logs=None):
        '''Initialize the learning rate to the minimum value at the start of training.'''
        logs = logs or {}
        K.set_value(self.model.optimizer.lr, self.min_lr)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.iteration += 1
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        self.history.setdefault('iterations', []).append(self.iteration)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)
        K.set_value(self.model.optimizer.lr, self.clr())

    def plot_lr(self):
        '''Helper function to quickly inspect the learning rate schedule.'''
        plt.plot(self.history['iterations'], self.history['lr'])
        plt.yscale('log')
        plt.xlabel('Iteration')
        plt.ylabel('Learning rate')
        plt.show()

    def plot_loss(self):
        '''Helper function to quickly observe the learning rate experiment results.'''
        plt.plot(self.history['lr'], self.history['loss'])
        plt.xscale('log')
        plt.xlabel('Learning rate')
        plt.ylabel('Loss')
        plt.show()

#%%
if __name__ == "__main__":
    pass
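
    # Minimal smoke test (an illustrative sketch): build a tiny toy regression
    # model on random data and run the LRFinder sweep. All data shapes, layer
    # sizes, and hyperparameters here are arbitrary assumptions chosen only to
    # exercise the callback, not values from the original module.
    from keras.models import Sequential
    from keras.layers import Dense

    rng = np.random.RandomState(0)
    X = rng.rand(256, 8)
    y = rng.rand(256, 1)
    batch_size = 32
    epochs = 3

    model = Sequential([
        Dense(16, activation='relu', input_shape=(8,)),
        Dense(1),
    ])
    model.compile(optimizer='sgd', loss='mse')

    lr_finder = LRFinder(min_lr=1e-5, max_lr=1e-2,
                         steps_per_epoch=int(np.ceil(len(X) / batch_size)),
                         epochs=epochs)
    model.fit(X, y, batch_size=batch_size, epochs=epochs, callbacks=[lr_finder])

    # Loss vs. learning rate on a log x-axis; pick the range where loss falls fastest.
    lr_finder.plot_loss()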