
Commit 0682553

committed Oct 3, 2017
ntm: initial NTM implementation
1 parent 6028b73 commit 0682553

15 files changed: +809 -2 lines changed
 

‎README.md

+48-2
Replaces the previous two-line README ("# pytorch-ntm" / "A Pytorch implementation of an NTM (Neural Turing Machine)") with:

# PyTorch Neural Turing Machine (NTM)

PyTorch implementation of [Neural Turing Machines](https://arxiv.org/abs/1410.5401) (NTM).

An **NTM** is a memory-augmented neural network (attached to external memory) where the interactions with the external memory (address, read, write) are done using differentiable transformations. As a result, the network is end-to-end differentiable and thus trainable by a gradient-based optimizer.

The NTM processes input in sequences, much like an LSTM, but with two additional benefits: (1) the external memory allows the network to learn algorithmic tasks more easily; (2) it provides a larger capacity without increasing the network's trainable parameters.

The external memory allows the NTM to learn algorithmic tasks that are much harder for an LSTM to learn, and to maintain an internal state much longer than traditional LSTMs.

## A PyTorch Implementation

This repository implements a vanilla NTM in a straightforward way. The following architecture is used:

![NTM Architecture](./images/ntm.png)

* Support for batch learning
* Any read/write head configuration is supported (for example, 5 read heads and 2 write heads); the order of operation is specified by the user

Example of training convergence for the **copy task** using 4 different seeds.

![NTM Convergence](./images/train.png)

The following plot shows the cost per sequence length during training. The network was trained with `seed=7` and shows fast convergence. Other seeds may not perform as well, but should converge in less than 30K iterations.

![NTM Convergence](./images/train2.png)

Here is an animated GIF that shows how the model generalizes. The model was evaluated after every 500 training samples, using the target sequence shown in the upper part of the image. The bottom part shows the network output at each training stage.

![NTM Convergence](./images/train-20.gif)

The following is the same, but with `sequence length = 80`. Note that the network was trained with sequences of lengths 1 to 20.

![NTM Convergence](./images/train-80.gif)

## Installation

The NTM can be used as a reusable module; it is not currently packaged.

1. Clone repository
2. Install [PyTorch](http://pytorch.org/)
3. pip install -r requirements.txt

## Usage

> python train.py
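The copy-task model can also be driven programmatically. Below is a minimal sketch (not part of the repository) based on `tasks/copytask.py` and `train.py`; the parameter values are illustrative, and the repository root is assumed to be on `PYTHONPATH`.

from tasks.copytask import CopyTaskParams, CopyTaskModelTraining

# Override any of the defaults; the wrapper wires the remaining components
params = CopyTaskParams(batch_size=4, num_batches=1000)
model = CopyTaskModelTraining(params=params)

# Same loop as train.py, without the reporting/checkpointing
for batch_num, x, y in model.dataloader:
    loss, cost = model.train_batch(model.net, model.criterion, model.optimizer, x, y)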

‎images/ntm.png

23.8 KB

‎images/train-20.gif

353 KB

‎images/train-80.gif

1.15 MB

‎images/train.png

25.4 KB

‎images/train2.png

54.1 KB

‎ntm/__init__.py

+1
__all__ = ['controller', 'head', 'memory', 'ntm']

‎ntm/controller.py

+47
"""LSTM Controller."""
import torch
from torch import nn
from torch.nn import Parameter
import numpy as np


class LSTMController(nn.Module):
    """An NTM controller based on LSTM."""
    def __init__(self, num_inputs, num_outputs, num_layers):
        super(LSTMController, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size=num_inputs,
                            hidden_size=num_outputs,
                            num_layers=num_layers)

        # The hidden state is a learned parameter
        self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)
        self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)

        self.reset_parameters()

    def create_new_state(self, batch_size):
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c

    def reset_parameters(self):
        for p in self.lstm.parameters():
            if p.dim() == 1:
                nn.init.constant(p, 0)
            else:
                stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
                nn.init.uniform(p, -stdev, stdev)

    def size(self):
        return self.num_inputs, self.num_outputs

    def forward(self, x, prev_state):
        x = x.unsqueeze(0)
        outp, state = self.lstm(x, prev_state)
        return outp.squeeze(0), state
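A minimal sketch (not part of the commit) of a single controller step; sizes are arbitrary and the PyTorch 0.2-era Variable API used above is assumed.

import torch
from torch.autograd import Variable
from ntm.controller import LSTMController

ctrl = LSTMController(num_inputs=28, num_outputs=100, num_layers=1)
state = ctrl.create_new_state(batch_size=4)   # (h, c), each (num_layers, batch, num_outputs)
x = Variable(torch.randn(4, 28))              # one time step: (batch, num_inputs)
out, state = ctrl(x, state)                   # out: (batch, num_outputs)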

‎ntm/head.py

+130
"""NTM Read/Write Head."""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np


def _split_cols(mat, lengths):
    """Split a 2D matrix to variable length columns."""
    assert mat.size()[1] == sum(lengths), "Lengths must sum to num columns"
    l = np.cumsum([0] + lengths)
    results = []
    for s, e in zip(l[:-1], l[1:]):
        results += [mat[:, s:e]]
    return results


class NTMHeadBase(nn.Module):
    """An NTM Read/Write Head."""

    def __init__(self, memory, controller_size):
        """Initialize the read/write head.

        :param memory: The :class:`NTMMemory` to be addressed by the head.
        :param controller_size: The size of the internal representation.
        """
        super(NTMHeadBase, self).__init__()

        self.memory = memory
        self.N, self.M = memory.size()
        self.controller_size = controller_size

    def create_new_state(self, batch_size):
        raise NotImplementedError

    def init_weights(self):
        raise NotImplementedError

    def is_read_head(self):
        raise NotImplementedError

    def _address_memory(self, k, β, g, s, γ, w_prev):
        # Activations
        k = F.relu(k)
        β = F.relu(β)
        g = F.sigmoid(g)
        s = F.softmax(F.relu(s))
        γ = 1 + F.relu(γ)

        w = self.memory.address(k, β, g, s, γ, w_prev)

        return w


class NTMReadHead(NTMHeadBase):
    def __init__(self, memory, controller_size):
        super(NTMReadHead, self).__init__(memory, controller_size)

        # Corresponding to k, β, g, s, γ sizes from the paper
        self.read_lengths = [self.M, 1, 1, 3, 1]
        self.fc_read = nn.Linear(controller_size, sum(self.read_lengths))
        self.reset_parameters()

    def create_new_state(self, batch_size):
        # The state holds the previous time step address weightings
        return Variable(torch.zeros(batch_size, self.N))

    def reset_parameters(self):
        # Initialize the linear layers
        nn.init.xavier_uniform(self.fc_read.weight, gain=1.4)
        nn.init.normal(self.fc_read.bias, std=0.01)

    def is_read_head(self):
        return True

    def forward(self, embeddings, w_prev):
        """NTMReadHead forward function.

        :param embeddings: input representation of the controller.
        :param w_prev: previous step state
        """
        o = self.fc_read(embeddings)
        k, β, g, s, γ = _split_cols(o, self.read_lengths)

        # Read from memory
        w = self._address_memory(k, β, g, s, γ, w_prev)
        r = self.memory.read(w)

        return r, w


class NTMWriteHead(NTMHeadBase):
    def __init__(self, memory, controller_size):
        super(NTMWriteHead, self).__init__(memory, controller_size)

        # Corresponding to k, β, g, s, γ, e, a sizes from the paper
        self.write_lengths = [self.M, 1, 1, 3, 1, self.M, self.M]
        self.fc_write = nn.Linear(controller_size, sum(self.write_lengths))
        self.reset_parameters()

    def create_new_state(self, batch_size):
        return Variable(torch.zeros(batch_size, self.N))

    def reset_parameters(self):
        # Initialize the linear layers
        nn.init.xavier_uniform(self.fc_write.weight, gain=1.4)
        nn.init.normal(self.fc_write.bias, std=0.01)

    def is_read_head(self):
        return False

    def forward(self, embeddings, w_prev):
        """NTMWriteHead forward function.

        :param embeddings: input representation of the controller.
        :param w_prev: previous step state
        """
        o = self.fc_write(embeddings)
        k, β, g, s, γ, e, a = _split_cols(o, self.write_lengths)

        # Handle activations
        e = F.relu(e)
        a = F.relu(a)

        # Write to memory
        w = self._address_memory(k, β, g, s, γ, w_prev)
        self.memory.write(w, e, a)

        return w
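A minimal sketch (not part of the commit) of a single read-head step against a freshly reset memory. Internally, `fc_read` produces an (M + 6)-column vector that `_split_cols` splits into k, β, g, s, γ. Sizes are arbitrary; the Variable API used above is assumed.

import torch
from torch.autograd import Variable
from ntm.memory import NTMMemory
from ntm.head import NTMReadHead

N, M, controller_size, batch_size = 128, 20, 100, 1
memory = NTMMemory(N, M)
memory.reset(batch_size)

head = NTMReadHead(memory, controller_size)
w_prev = head.create_new_state(batch_size)                   # previous weighting: (batch, N), zeros
embeddings = Variable(torch.randn(batch_size, controller_size))
r, w = head(embeddings, w_prev)                              # r: (batch, M) read vector, w: (batch, N) weighting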

‎ntm/memory.py

+102
"""An NTM's memory implementation."""
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
import numpy as np


def _convolve(w, s):
    """Circular convolution implementation."""
    assert s.size(0) == 3
    t = torch.cat([w[-2:], w, w[:2]])
    c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
    return c[1:-1]


class NTMMemory(nn.Module):
    """Memory bank for NTM."""
    def __init__(self, N, M):
        """Initialize the NTM Memory matrix.

        The memory's dimensions are (batch_size x N x M).
        Each batch has its own memory matrix.

        :param N: Number of rows in the memory.
        :param M: Number of columns/features in the memory.
        """
        super(NTMMemory, self).__init__()

        self.N = N
        self.M = M

        # The memory bias allows the heads to learn how to initially address
        # memory locations by content
        self.mem_bias = Variable(torch.Tensor(N, M))
        self.register_buffer('mem_bias', self.mem_bias.data)

        # Initialize memory bias
        stdev = 1 / (np.sqrt(N + M))
        nn.init.uniform(self.mem_bias, -stdev, stdev)

    def reset(self, batch_size):
        """Initialize memory from bias, for start-of-sequence."""
        self.batch_size = batch_size
        self.memory = self.mem_bias.clone().repeat(batch_size, 1, 1)

    def size(self):
        return self.N, self.M

    def read(self, w):
        """Read from memory (according to section 3.1)."""
        return torch.matmul(w.unsqueeze(1), self.memory).squeeze(1)

    def write(self, w, e, a):
        """Write to memory (according to section 3.2)."""
        self.prev_mem = self.memory
        self.memory = Variable(torch.Tensor(self.batch_size, self.N, self.M))
        for b in range(self.batch_size):
            erase = torch.ger(w[b], e[b])
            add = torch.ger(w[b], a[b])
            self.memory[b] = self.prev_mem[b] * (1 - erase) + add

    def address(self, k, β, g, s, γ, w_prev):
        """NTM Addressing (according to section 3.3).

        Returns a softmax weighting over the rows of the memory matrix.

        :param k: The key vector.
        :param β: The key strength (focus).
        :param g: Scalar interpolation gate (with previous weighting).
        :param s: Shift weighting.
        :param γ: Sharpen weighting scalar.
        :param w_prev: The weighting produced in the previous time step.
        """
        # Content focus
        wc = self._similarity(k, β)

        # Location focus
        wg = self._interpolate(w_prev, wc, g)
        ŵ = self._shift(wg, s)
        w = self._sharpen(ŵ, γ)

        return w

    def _similarity(self, k, β):
        k = k.view(self.batch_size, 1, -1)
        w = F.softmax(β * F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1))
        return w

    def _interpolate(self, w_prev, wc, g):
        return g * wc + (1 - g) * w_prev

    def _shift(self, wg, s):
        result = Variable(torch.zeros(wg.size()))
        for b in range(self.batch_size):
            result[b] = _convolve(wg[b], s[b])
        return result

    def _sharpen(self, ŵ, γ):
        w = ŵ ** γ
        w = torch.div(w, torch.sum(w, dim=1).view(-1, 1) + 1e-16)
        return w
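A small sketch (not part of the commit) of the circular shift used by `address` above. With a 3-element shift vector s, `_convolve` computes result[i] = s[0]*w[i-1] + s[1]*w[i] + s[2]*w[i+1] with wrap-around indices, so s = [0, 1, 0] keeps the focus in place and s = [1, 0, 0] moves a sharp focus from row p to row p+1. The Variable API used above is assumed.

import torch
from torch.autograd import Variable
from ntm.memory import _convolve

w = Variable(torch.Tensor([0., 0., 1., 0., 0.]))            # focus on row 2
print(_convolve(w, Variable(torch.Tensor([0., 1., 0.]))))   # [0, 0, 1, 0, 0]  (no shift)
print(_convolve(w, Variable(torch.Tensor([1., 0., 0.]))))   # [0, 0, 0, 1, 0]  (shifted to row 3)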

‎ntm/ntm.py

+95
#!/usr/bin/env python
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F


class NTM(nn.Module):
    """A Neural Turing Machine."""
    def __init__(self, num_inputs, num_outputs, controller, memory, heads):
        """Initialize the NTM.

        :param num_inputs: External input size.
        :param num_outputs: External output size.
        :param controller: :class:`LSTMController`
        :param memory: :class:`NTMMemory`
        :param heads: list of :class:`NTMReadHead` or :class:`NTMWriteHead`

        Note: This design allows the flexibility of using any number of read and
        write heads independently; also, the order in which the heads are
        called is controlled by the user (order in the list).
        """
        super(NTM, self).__init__()

        # Save arguments
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.controller = controller
        self.memory = memory
        self.heads = heads

        self.N, self.M = memory.size()
        _, self.controller_size = controller.size()

        # Initialize the initial previous read values to random biases
        self.num_read_heads = 0
        self.init_r = []
        for head in heads:
            if head.is_read_head():
                init_r_bias = Variable(torch.randn(1, self.M) * 0.01)
                self.register_buffer("read{}_bias".format(self.num_read_heads), init_r_bias.data)
                self.init_r += [init_r_bias]
                self.num_read_heads += 1

        assert self.num_read_heads > 0, "heads list must contain at least a single read head"

        # Initialize a fully connected layer to produce the actual output:
        # [controller_output; previous_reads] -> output
        self.fc = nn.Linear(self.controller_size + self.num_read_heads * self.M, num_outputs)
        self.reset_parameters()

    def create_new_state(self, batch_size):
        init_r = [r.clone().repeat(batch_size, 1) for r in self.init_r]
        controller_state = self.controller.create_new_state(batch_size)
        heads_state = [head.create_new_state(batch_size) for head in self.heads]

        return init_r, controller_state, heads_state

    def reset_parameters(self):
        # Initialize the linear layer
        nn.init.xavier_uniform(self.fc.weight, gain=1)
        nn.init.normal(self.fc.bias, std=0.01)

    def forward(self, x, prev_state):
        """NTM forward function.

        :param x: input vector (batch_size x num_inputs)
        :param prev_state: The previous state of the NTM
        """
        # Unpack the previous state
        prev_reads, prev_controller_state, prev_heads_states = prev_state

        # Use the controller to get an embedding
        inp = torch.cat([x] + prev_reads, dim=1)
        controller_outp, controller_state = self.controller(inp, prev_controller_state)

        # Read/Write from the list of heads
        reads = []
        heads_states = []
        for head, prev_head_state in zip(self.heads, prev_heads_states):
            if head.is_read_head():
                r, head_state = head(controller_outp, prev_head_state)
                reads += [r]
            else:
                head_state = head(controller_outp, prev_head_state)
            heads_states += [head_state]

        # Generate output
        inp2 = torch.cat([controller_outp] + reads, dim=1)
        o = F.sigmoid(self.fc(inp2))

        # Pack the current state
        state = (reads, controller_state, heads_states)

        return o, state
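A minimal sketch (not part of the commit) of wiring the modules together and unrolling the NTM over a sequence; the caller owns the time loop and threads the state tuple through each step. Sizes are illustrative; the Variable API used above is assumed.

import torch
from torch import nn
from torch.autograd import Variable

from ntm.memory import NTMMemory
from ntm.controller import LSTMController
from ntm.head import NTMReadHead, NTMWriteHead
from ntm.ntm import NTM

# Illustrative sizes
memory = NTMMemory(N=64, M=16)
heads = nn.ModuleList([NTMReadHead(memory, 64), NTMWriteHead(memory, 64)])
controller = LSTMController(4 + 16, 64, 1)   # external input (4) + one read vector (M=16)
ntm = NTM(num_inputs=4, num_outputs=4, controller=controller, memory=memory, heads=heads)

# Start a new sequence, then feed it one step at a time
memory.reset(batch_size=2)
state = ntm.create_new_state(batch_size=2)
seq = Variable(torch.rand(10, 2, 4))         # (seq_len, batch, num_inputs)
outputs = []
for t in range(seq.size(0)):
    o, state = ntm(seq[t], state)            # o: (batch, num_outputs), values in (0, 1)
    outputs += [o]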

‎requirements.txt

+5
argcomplete
attrs
numpy
pytest

‎tasks/__init__.py

Whitespace-only changes.

‎tasks/copytask.py

+240
"""Copy Task NTM model."""
from attr import attrs, attrib, Factory
import torch
from torch import nn
from torch.autograd import Variable
from torch import optim
import random
import numpy as np

from ntm.controller import LSTMController
from ntm.memory import NTMMemory
from ntm.head import NTMReadHead, NTMWriteHead
from ntm.ntm import NTM


# Encapsulation of the various NTM components for the Copy task
class CopyTaskNTM(nn.Module):

    def __init__(self, num_inputs, num_outputs, controller_size, controller_layers, N, M):
        """Initialize a CopyTaskNTM.

        :param num_inputs: External number of inputs.
        :param num_outputs: External number of outputs.
        :param controller_size: The size of the internal representation.
        :param controller_layers: Number of controller layers.
        :param N: Number of rows in the memory bank.
        :param M: Number of cols/features in the memory bank.
        """
        super(CopyTaskNTM, self).__init__()

        # Save args
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.controller_size = controller_size
        self.N = N
        self.M = M

        # Create the NTM components
        memory = NTMMemory(N, M)
        heads = nn.ModuleList([
            NTMReadHead(memory, controller_size),
            NTMWriteHead(memory, controller_size)
        ])
        controller = LSTMController(num_inputs + M, controller_size, controller_layers)
        self.ntm = NTM(num_inputs, num_outputs, controller, memory, heads)
        self.memory = memory

    def init_sequence(self, batch_size):
        """Initialize the state."""
        self.batch_size = batch_size
        self.memory.reset(batch_size)
        self.previous_state = self.ntm.create_new_state(batch_size)

    def forward(self, x=None):
        if x is None:
            x = Variable(torch.zeros(self.batch_size, self.num_inputs))

        o, self.previous_state = self.ntm(x, self.previous_state)
        return o, self.previous_state

    def calculate_num_params(self):
        """Returns the total number of parameters."""
        num_params = 0
        for p in self.parameters():
            num_params += p.data.view(-1).size(0)
        return num_params


# Generator of randomized test sequences
def dataloader(num_batches,
               batch_size,
               seq_width,
               min_len,
               max_len):
    """Generator of random sequences for the copy task.

    Creates random batches of "bits" sequences.

    All the sequences within each batch have the same length.
    The length is drawn uniformly from [`min_len`, `max_len`].

    :param num_batches: Total number of batches to generate.
    :param seq_width: The width of each item in the sequence.
    :param batch_size: Batch size.
    :param min_len: Sequence minimum length.
    :param max_len: Sequence maximum length.

    NOTE: The input width is `seq_width + 1`; the additional input
    channel contains the delimiter.
    """
    for batch_num in range(num_batches):

        # All batches have the same sequence length
        seq_len = random.randint(min_len, max_len)
        seq = np.random.binomial(1, 0.5, (seq_len, batch_size, seq_width))
        seq = Variable(torch.from_numpy(seq))

        # The input includes an additional channel used for the delimiter
        inp = Variable(torch.zeros(seq_len + 1, batch_size, seq_width + 1))
        inp[:seq_len, :, :seq_width] = seq
        inp[seq_len, :, seq_width] = 1.0  # delimiter in our control channel
        outp = seq.clone()

        yield batch_num+1, inp.float(), outp.float()


def train_batch(net, criterion, optimizer, X, Y):
    """Trains a single batch."""
    optimizer.zero_grad()
    seq_len, batch_size, _ = Y.size()

    # New sequence
    net.init_sequence(batch_size)

    # Feed the sequence + delimiter
    for i in range(seq_len+1):
        net(X[i])

    # Read the output (no input given)
    y_out = Variable(torch.zeros(Y.size()))
    for i in range(seq_len):
        y_out[i], _ = net()

    loss = criterion(y_out, Y)
    loss.backward()
    optimizer.step()

    y_out_binarized = y_out.clone().data
    y_out_binarized.apply_(lambda x: 0 if x < 0.5 else 1)

    # The cost is the number of error bits per sequence
    cost = torch.sum(torch.abs(y_out_binarized - Y.data))

    return loss.data[0], cost / batch_size


def evaluate(net, criterion, X, Y):
    """Evaluate a single batch (without training)."""
    seq_len, batch_size, _ = Y.size()

    # New sequence
    net.init_sequence(batch_size)

    # Feed the sequence + delimiter
    for i in range(seq_len+1):
        o, _ = net(X[i])

    # Read the output (no input given)
    states = []
    y_out = Variable(torch.zeros(Y.size()))
    for i in range(seq_len):
        y_out[i], state = net()
        states += [state]

    loss = criterion(y_out, Y)

    y_out_binarized = y_out.clone().data
    y_out_binarized.apply_(lambda x: 0 if x < 0.5 else 1)

    # The cost is the number of error bits per sequence
    cost = torch.sum(torch.abs(y_out_binarized - Y.data))

    result = {
        'loss': loss.data[0],
        'cost': cost / batch_size,
        'y_out': y_out,
        'y_out_binarized': y_out_binarized,
        'states': states
    }

    return result


@attrs
class CopyTaskParams(object):
    name = attrib(default="copy-task")
    controller_size = attrib(default=100)
    controller_layers = attrib(default=1)
    sequence_width = attrib(default=8)
    sequence_min_len = attrib(default=1)
    sequence_max_len = attrib(default=20)
    memory_n = attrib(default=128)
    memory_m = attrib(default=20)
    num_batches = attrib(default=50000)
    batch_size = attrib(default=1)
    rmsprop_lr = attrib(default=1e-4)
    rmsprop_momentum = attrib(default=0.9)
    rmsprop_alpha = attrib(default=0.95)


#
# To create a network, simply instantiate `:class:CopyTaskModelTraining`;
# all the components will be wired with the default values.
# In case you'd like to change any of the defaults, do the following:
#
# > params = CopyTaskParams(batch_size=4)
# > model = CopyTaskModelTraining(params=params)
#
# Then use `model.net`, `model.optimizer` and `model.criterion` to train the
# network. Call `model.train_batch` for training and `model.evaluate`
# for evaluating.
#
# You may skip this altogether, and use `:class:CopyTaskNTM` directly.
#

@attrs
class CopyTaskModelTraining(object):
    params = attrib(default=Factory(CopyTaskParams))
    train_batch = attrib(default=train_batch)
    evaluate = attrib(default=evaluate)
    net = attrib()
    dataloader = attrib()
    criterion = attrib()
    optimizer = attrib()

    @net.default
    def default_net(self):
        # We have 1 additional input for the delimiter which is passed on a
        # separate "control" channel
        net = CopyTaskNTM(self.params.sequence_width + 1, self.params.sequence_width,
                          self.params.controller_size, self.params.controller_layers,
                          self.params.memory_n, self.params.memory_m)
        return net

    @dataloader.default
    def default_dataloader(self):
        return dataloader(self.params.num_batches, self.params.batch_size,
                          self.params.sequence_width,
                          self.params.sequence_min_len, self.params.sequence_max_len)

    @criterion.default
    def default_criterion(self):
        return nn.BCELoss()

    @optimizer.default
    def default_optimizer(self):
        return optim.RMSprop(self.net.parameters(),
                             momentum=self.params.rmsprop_momentum,
                             alpha=self.params.rmsprop_alpha,
                             lr=self.params.rmsprop_lr)
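For reference, a small sketch (not part of the commit) of what the `dataloader` generator above yields; the shapes follow its docstring, with illustrative argument values.

from tasks.copytask import dataloader

gen = dataloader(num_batches=1, batch_size=4, seq_width=8, min_len=5, max_len=5)
batch_num, x, y = next(gen)
# x: (seq_len + 1, batch, seq_width + 1) = (6, 4, 9) -- the extra step/channel is the delimiter
# y: (seq_len, batch, seq_width)         = (5, 4, 8)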

‎train.py

+141
#!/usr/bin/env python
# PYTHON_ARGCOMPLETE_OK
"""Training for the Copy Task in Neural Turing Machines."""

import argparse
import json
import time
import random

import argcomplete
import torch
import numpy as np

from tasks.copytask import CopyTaskModelTraining


# Default values for program arguments
RANDOM_SEED = 1000
REPORT_INTERVAL = 200
CHECKPOINT_INTERVAL = 1000


def get_ms():
    """Returns the current time in milliseconds."""
    return time.time() * 1000


def init_seed(seed=None):
    """Seed the RNGs for predictability/reproduction purposes."""
    if seed is None:
        seed = int(get_ms() // 1000)

    print("Using seed={}".format(seed))
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def progress_clean():
    """Clean the progress bar."""
    print("\r{}".format(" " * 80), end='\r')


def progress_bar(batch_num, report_interval, last_loss):
    """Prints the progress until the next report."""
    progress = (((batch_num-1) % report_interval) + 1) / report_interval
    fill = int(progress * 40)
    print("\r[{}{}]: {} (Loss: {:.4f})".format(
        "=" * fill, " " * (40 - fill), batch_num, last_loss), end='')


def save_checkpoint(net, name, seed, batch_num, losses, costs, seq_lengths):
    progress_clean()

    basename = "./{}-{}-batch-{}".format(name, seed, batch_num)
    model_fname = basename + ".model"
    print("Saving model checkpoint to: '{}'".format(model_fname))
    torch.save(net.state_dict(), model_fname)

    # Save the training history
    train_fname = basename + ".json"
    print("Saving model training history to '{}'".format(train_fname))
    content = {
        "loss": losses,
        "cost": costs,
        "seq_lengths": seq_lengths
    }
    open(train_fname, 'wt').write(json.dumps(content))


def train_model(model, args):

    num_batches = model.params.num_batches
    batch_size = model.params.batch_size

    print("Training model for {} batches (batch_size={})...".format(
        num_batches, batch_size))

    losses = []
    costs = []
    seq_lengths = []
    start_ms = get_ms()

    for batch_num, x, y in model.dataloader:
        loss, cost = model.train_batch(model.net, model.criterion, model.optimizer, x, y)
        losses += [loss]
        costs += [cost]
        seq_lengths += [y.size(0)]

        # Update the progress bar
        progress_bar(batch_num, args.report_interval, loss)

        # Report
        if batch_num % args.report_interval == 0:
            mean_loss = np.array(losses[-args.report_interval:]).mean()
            mean_cost = np.array(costs[-args.report_interval:]).mean()
            mean_time = int(((get_ms() - start_ms) / args.report_interval) / batch_size)
            progress_clean()
            print("Batch {} Loss: {:.6f} Cost: {:.2f} Time: {} ms/sequence".format(
                batch_num, mean_loss, mean_cost, mean_time))
            start_ms = get_ms()

        # Checkpoint
        if (args.checkpoint_interval != 0) and (batch_num % args.checkpoint_interval == 0):
            save_checkpoint(model.net, model.params.name, args.seed,
                            batch_num, losses, costs, seq_lengths)

    print("Done training.")


def init_arguments():
    parser = argparse.ArgumentParser(prog='train.py')
    parser.add_argument('--seed', type=int, default=RANDOM_SEED, help="Seed value for RNGs")
    parser.add_argument('--checkpoint_interval', type=int, default=CHECKPOINT_INTERVAL,
                        help="Checkpoint interval (in batches). 0 - disable")
    parser.add_argument('--report_interval', type=int, default=REPORT_INTERVAL,
                        help="Report interval (in batches)")

    argcomplete.autocomplete(parser)

    args = parser.parse_args()
    return args


def main():
    # Initialize arguments
    args = init_arguments()

    # Initialize random
    init_seed(args.seed)

    # Initialize the model
    model = CopyTaskModelTraining()

    print("Total number of parameters: {}".format(model.net.calculate_num_params()))
    train_model(model, args)


if __name__ == '__main__':
    main()
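Example invocation with explicit arguments (the flags defined in `init_arguments` above; values are illustrative):

> python train.py --seed 7 --report_interval 200 --checkpoint_interval 1000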
