# AttentionNMT.py
import tensorflow as tf

HIDDEN_SIZE = 1024            # Size of the LSTM hidden state
DECODER_LAYERS = 2            # Number of decoder LSTM layers (the encoder is a single bidirectional LSTM layer)
SRC_VOCAB_SIZE = 10000        # Source vocabulary size
TRG_VOCAB_SIZE = 4000         # Target vocabulary size
BATCH_SIZE = 100              # Training batch size
KEEP_PROB = 0.8               # Probability that a node is kept (not dropped out)
MAX_GRAD_NORM = 5             # Maximum global norm for gradient clipping
SHARE_EMB_AND_SOFTMAX = True  # Share weights between the softmax layer and the target embedding

# IDs of <sos> and <eos> in the vocabulary.
# During decoding, <sos> is the first input to the decoder,
# and <eos> marks the point at which a sentence is complete.
SOS_ID = 1
EOS_ID = 2

class AttentionNMTModel(object):
    def __init__(self):
        # Define the encoder and decoder cells.
        self.enc_cell_fw = tf.nn.rnn_cell.LSTMCell(HIDDEN_SIZE, name='basic_lstm_cell')
        self.enc_cell_bw = tf.nn.rnn_cell.LSTMCell(HIDDEN_SIZE, name='basic_lstm_cell')
        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.LSTMCell(HIDDEN_SIZE, name='basic_lstm_cell')
             for _ in range(DECODER_LAYERS)])

        # Embeddings for the source and target languages.
        self.src_embedding = tf.get_variable(
            "src_emb", [SRC_VOCAB_SIZE, HIDDEN_SIZE])
        self.trg_embedding = tf.get_variable(
            "trg_emb", [TRG_VOCAB_SIZE, HIDDEN_SIZE])

        # Softmax layer weights, optionally shared with the target embedding.
        if SHARE_EMB_AND_SOFTMAX:
            self.softmax_weight = tf.transpose(self.trg_embedding)
        else:
            self.softmax_weight = tf.get_variable(
                "weight", [HIDDEN_SIZE, TRG_VOCAB_SIZE])
        self.softmax_bias = tf.get_variable(
            "softmax_bias", [TRG_VOCAB_SIZE])

    # Define the forward computation graph.
    def forward(self, src_input, src_size, trg_input, trg_label, trg_size):
        batch_size = tf.shape(src_input)[0]

        # Convert the input and output word IDs to embeddings.
        src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)
        trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)

        # Apply dropout to the embeddings.
        src_emb = tf.nn.dropout(src_emb, KEEP_PROB)
        trg_emb = tf.nn.dropout(trg_emb, KEEP_PROB)

        # Build the encoder.
        # The encoder is a single-layer bidirectional LSTM. It reads the embedding
        # at every position; its final state is discarded, and only the per-step
        # outputs (enc_outputs) are kept as the memory for the attention mechanism.
        with tf.variable_scope("encoder"):
            # bidirectional_dynamic_rnn returns a tuple of two tensors, one per
            # direction, each of shape [batch_size, max_time, HIDDEN_SIZE],
            # holding that LSTM's output at every step.
            enc_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                self.enc_cell_fw, self.enc_cell_bw, src_emb, src_size,
                dtype=tf.float32)
            # Concatenate the two directions into a single tensor of shape
            # [batch_size, max_time, 2 * HIDDEN_SIZE].
            enc_outputs = tf.concat([enc_outputs[0], enc_outputs[1]], -1)

        # Build the decoder.
        # The decoder reads the target embedding at every position and emits the
        # output of its top LSTM layer at each step, so dec_outputs has shape
        # [batch_size, max_time, HIDDEN_SIZE].
        with tf.variable_scope("decoder"):
            # Choose how the attention weights are computed.
            # BahdanauAttention is a fully connected network with one hidden layer.
            # memory_sequence_length is a [batch_size] tensor holding the true
            # length of each source sentence in the batch; the attention mechanism
            # uses it to set the weights of padded positions to 0.
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                HIDDEN_SIZE, enc_outputs,
                memory_sequence_length=src_size)

            # Wrap the decoder cell and the attention model into a single,
            # higher-level RNN cell.
            attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                self.dec_cell, attention_mechanism,
                attention_layer_size=HIDDEN_SIZE)

            # Run the decoder with attention_cell and dynamic_rnn.
            # No initial_state is passed, i.e. the encoder's final state is not
            # used to initialize the decoder; the decoder obtains all source
            # information through attention.
            dec_outputs, _ = tf.nn.dynamic_rnn(
                attention_cell, trg_emb, trg_size, dtype=tf.float32)

        # Compute the decoder's log-perplexity.
        output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])
        logits = tf.matmul(output, self.softmax_weight) + self.softmax_bias
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.reshape(trg_label, [-1]), logits=logits)

        # When averaging the loss, give padded (invalid) positions a weight of 0
        # so they do not interfere with the training objective.
        label_weights = tf.sequence_mask(
            trg_size, maxlen=tf.shape(trg_label)[1], dtype=tf.float32)
        label_weights = tf.reshape(label_weights, [-1])
        cost = tf.reduce_sum(loss * label_weights)
        cost_per_token = cost / tf.reduce_sum(label_weights)

        # Define the backward pass.
        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(cost / tf.to_float(batch_size),
                             trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, MAX_GRAD_NORM)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = optimizer.apply_gradients(
            zip(grads, trainable_variables))
        return cost_per_token, train_op

    def inference(self, src_input):
        # Although only one sentence is translated at a time, dynamic_rnn
        # requires batched input, so reshape it to [1, sentence_length].
        src_size = tf.convert_to_tensor([len(src_input)], dtype=tf.int32)
        src_input = tf.convert_to_tensor([src_input], dtype=tf.int32)
        src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)

        # Build the encoder with bidirectional_dynamic_rnn
        # (the same structure as in forward).
        with tf.variable_scope("encoder"):
            enc_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                self.enc_cell_fw, self.enc_cell_bw, src_emb, src_size,
                dtype=tf.float32)
            enc_outputs = tf.concat([enc_outputs[0], enc_outputs[1]], -1)

        with tf.variable_scope("decoder"):
            # Define the decoder's attention mechanism.
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                HIDDEN_SIZE, enc_outputs,
                memory_sequence_length=src_size)
            # Wrap self.dec_cell and the attention mechanism together.
            attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                self.dec_cell, attention_mechanism,
                attention_layer_size=HIDDEN_SIZE)

        # Cap the number of decoding steps to avoid an infinite loop
        # in pathological cases.
        MAX_DEC_LEN = 100

        with tf.variable_scope("decoder/rnn/attention_wrapper"):
            # Use a dynamically sized TensorArray to store the generated sentence.
            init_array = tf.TensorArray(dtype=tf.int32, size=0,
                                        dynamic_size=True, clear_after_read=False)
            # Write the first word, <sos>, as the decoder's first input.
            init_array = init_array.write(0, SOS_ID)
            # Call attention_cell.zero_state to build the initial loop state.
            # The loop state holds the RNN hidden state, the TensorArray with the
            # generated sentence, and an integer recording the decoding step.
            init_loop_var = (
                attention_cell.zero_state(batch_size=1, dtype=tf.float32),
                init_array, 0)

            # Loop condition for tf.while_loop:
            # keep decoding until <eos> is produced or the step limit is reached.
            def continue_loop_condition(state, trg_ids, step):
                return tf.reduce_all(tf.logical_and(
                    tf.not_equal(trg_ids.read(step), EOS_ID),
                    tf.less(step, MAX_DEC_LEN - 1)))

            def loop_body(state, trg_ids, step):
                # Read the last generated word and look up its embedding.
                trg_input = [trg_ids.read(step)]
                trg_emb = tf.nn.embedding_lookup(self.trg_embedding,
                                                 trg_input)
                # Run attention_cell for a single forward step.
                dec_outputs, next_state = attention_cell.call(
                    state=state, inputs=trg_emb)
                # Compute the logit of every candidate word and pick the word
                # with the maximum logit as this step's output.
                output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])
                logits = (tf.matmul(output, self.softmax_weight)
                          + self.softmax_bias)
                next_id = tf.argmax(logits, axis=1, output_type=tf.int32)
                # Append this word to trg_ids in the loop state.
                trg_ids = trg_ids.write(step + 1, next_id[0])
                return next_state, trg_ids, step + 1

            # Run tf.while_loop until the loop terminates and return the result.
            state, trg_ids, step = tf.while_loop(
                continue_loop_condition, loop_body, init_loop_var)
            return trg_ids.stack()
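

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# model): it wires the training graph to placeholder inputs and runs a single
# gradient step on a dummy batch. The function name `main`, the placeholder
# names, and the random dummy data are assumptions; a real setup would feed
# padded ID batches from a data pipeline and add checkpointing.
# ---------------------------------------------------------------------------
import numpy as np


def main():
    # Variables created inside this scope pick up the uniform initializer,
    # a common setup for TF 1.x seq2seq models.
    initializer = tf.random_uniform_initializer(-0.05, 0.05)
    with tf.variable_scope("nmt_model", reuse=None, initializer=initializer):
        train_model = AttentionNMTModel()

    # Placeholders for one padded batch of word IDs and the true lengths.
    src = tf.placeholder(tf.int32, [None, None], name="src_input")
    src_size = tf.placeholder(tf.int32, [None], name="src_size")
    trg_input = tf.placeholder(tf.int32, [None, None], name="trg_input")
    trg_label = tf.placeholder(tf.int32, [None, None], name="trg_label")
    trg_size = tf.placeholder(tf.int32, [None], name="trg_size")

    cost_op, train_op = train_model.forward(src, src_size, trg_input,
                                            trg_label, trg_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # A dummy batch of 2 sentences with 5 tokens each, just to check that
        # the graph runs end to end.
        feed = {
            src: np.random.randint(0, SRC_VOCAB_SIZE, size=(2, 5), dtype=np.int32),
            src_size: [5, 5],
            trg_input: np.random.randint(0, TRG_VOCAB_SIZE, size=(2, 5), dtype=np.int32),
            trg_label: np.random.randint(0, TRG_VOCAB_SIZE, size=(2, 5), dtype=np.int32),
            trg_size: [5, 5],
        }
        cost, _ = sess.run([cost_op, train_op], feed_dict=feed)
        print("cost per token on the dummy batch: %.3f" % cost)


if __name__ == "__main__":
    main()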