text_han.py
import tensorflow as tf
import numpy as np
class TextHAN(object):
"""
A Hierarchical Attention Network for text classification/regression.
    Uses an embedding layer, followed by a Word Encoder, Word Attention, Sentence Encoder, Sentence Attention, and a fully-connected (softmax) output layer.
"""
def __init__(
self, model_type, sequence_length, num_sentences, num_classes, vocab_size,
embedding_size, hidden_size, batch_size, l2_reg_lambda=0.5): # batch_size,
# parameters
self.sequence_length = sequence_length
self.num_sentences = num_sentences
self.vocab_size = vocab_size
self.embed_size = embedding_size
self.hidden_size = hidden_size
self.batch_size = batch_size
self.initializer = tf.random_normal_initializer(stddev=0.1)
# Placeholders for input, output and dropout
self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
self.learning_rate = tf.placeholder(tf.float32, name="learning_rate")
        self.sequence_length = int(self.sequence_length / self.num_sentences)  # TODO: from here on, sequence_length is the number of words per sentence
# Keeping track of l2 regularization loss (optional)
l2_loss = tf.constant(0.0)
# Initialize weights
self.instantiate_weights()
# Create a han
with tf.name_scope('han'):
# 1.Word Encoder
# 1.1 embedding of words
input_x = tf.split(self.input_x, self.num_sentences, axis=1) # a list. length:num_sentences.each element is:[None,self.sequence_length/num_sentences]
input_x = tf.stack(input_x, axis=1) # shape:[None,self.num_sentences,self.sequence_length/num_sentences]
self.embedded_words = tf.nn.embedding_lookup(self.W, input_x) # [None,num_sentences,sentence_length,embed_size]
embedded_words_reshaped = tf.reshape(self.embedded_words, shape=[-1, self.sequence_length, self.embed_size]) # [batch_size*num_sentences,sentence_length,embed_size]
# 1.2 forward gru
hidden_state_forward_list = self.gru_forward_word_level(embedded_words_reshaped) # a list,length is sentence_length, each element is [batch_size*num_sentences,hidden_size]
# 1.3 backward gru
hidden_state_backward_list = self.gru_backward_word_level(embedded_words_reshaped) # a list,length is sentence_length, each element is [batch_size*num_sentences,hidden_size]
# 1.4 concat forward hidden state and backward hidden state. hidden_state: a list.len:sentence_length,element:[batch_size*num_sentences,hidden_size*2]
self.hidden_state = [tf.concat([h_forward, h_backward], axis=1) for h_forward, h_backward in
zip(hidden_state_forward_list, hidden_state_backward_list)] # hidden_state:list,len:sentence_length,element:[batch_size*num_sentences,hidden_size*2]
# 2.Word Attention
# for each sentence.
sentence_representation = self.attention_word_level(self.hidden_state) # output:[batch_size*num_sentences,hidden_size*2]
sentence_representation = tf.reshape(sentence_representation, shape=[-1, self.num_sentences, self.hidden_size * 2]) # shape:[batch_size,num_sentences,hidden_size*2]
#with tf.name_scope("dropout"):#TODO
# sentence_representation = tf.nn.dropout(sentence_representation,keep_prob=self.dropout_keep_prob) # shape:[None,hidden_size*4]
# 3.Sentence Encoder
# 3.1) forward gru for sentence
hidden_state_forward_sentences = self.gru_forward_sentence_level(sentence_representation) # a list.length is sentence_length, each element is [None,hidden_size]
# 3.2) backward gru for sentence
hidden_state_backward_sentences = self.gru_backward_sentence_level(sentence_representation) # a list,length is sentence_length, each element is [None,hidden_size]
# 3.3) concat forward hidden state and backward hidden state
# below hidden_state_sentence is a list,len:sentence_length,element:[None,hidden_size*2]
self.hidden_state_sentence = [tf.concat([h_forward, h_backward], axis=1) for h_forward, h_backward in zip(hidden_state_forward_sentences, hidden_state_backward_sentences)]
# 4.Sentence Attention
document_representation = self.attention_sentence_level(self.hidden_state_sentence) # shape:[None,hidden_size*4]
self.output = document_representation
# Add dropout
with tf.name_scope("dropout"):
self.rnn_drop = tf.nn.dropout(self.output, self.dropout_keep_prob)
# Final (unnormalized) scores and predictions
with tf.name_scope("output"):
W = tf.get_variable(
"W",
shape=[hidden_size*4, num_classes],
initializer=tf.contrib.layers.xavier_initializer())
b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
l2_loss += tf.nn.l2_loss(W)
l2_loss += tf.nn.l2_loss(b)
self.scores = tf.nn.xw_plus_b(self.rnn_drop, W, b, name="scores")
if model_type == 'clf':
self.predictions = tf.argmax(self.scores, 1, name="predictions")
elif model_type == 'reg':
self.predictions = tf.reduce_max(self.scores, 1, name="predictions")
self.predictions = tf.expand_dims(self.predictions, -1)
# Calculate mean cross-entropy loss, or root-mean-square error loss
with tf.name_scope("loss"):
if model_type == 'clf':
losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss
elif model_type == 'reg':
losses = tf.sqrt(tf.losses.mean_squared_error(predictions=self.predictions, labels=self.input_y))
self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss
# Accuracy
with tf.name_scope("accuracy"):
if model_type == 'clf':
correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
elif model_type == 'reg':
self.accuracy = tf.constant(0.0, name="accuracy")
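        # Note: no training op is created here. The (otherwise unused) learning_rate placeholder
        # suggests the optimizer is built by the training script, for example (hypothetical sketch):
        #   train_op = tf.train.AdamOptimizer(model.learning_rate).minimize(model.loss)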
def instantiate_weights(self):
"""define all weights here"""
# Embedding layer
with tf.name_scope("embedding"):
            # When trainable=True the embedding matrix is updated during training (non-static); otherwise it is kept fixed (static).
self.W = tf.Variable(
tf.random_uniform([self.vocab_size, self.embed_size], -1.0, 1.0),
name="W", trainable=True)
        with tf.name_scope("gru_weights_word_level"):
            # GRU parameters: update gate related.
            # Note: the recurrent matrices U_* act on the previous hidden state, so they are [hidden_size, hidden_size].
            self.W_z = tf.get_variable("W_z", shape=[self.embed_size, self.hidden_size], initializer=self.initializer)
            self.U_z = tf.get_variable("U_z", shape=[self.hidden_size, self.hidden_size], initializer=self.initializer)
            self.b_z = tf.get_variable("b_z", shape=[self.hidden_size])
            # GRU parameters: reset gate related
            self.W_r = tf.get_variable("W_r", shape=[self.embed_size, self.hidden_size], initializer=self.initializer)
            self.U_r = tf.get_variable("U_r", shape=[self.hidden_size, self.hidden_size], initializer=self.initializer)
            self.b_r = tf.get_variable("b_r", shape=[self.hidden_size])
            # GRU parameters: candidate state related
            self.W_h = tf.get_variable("W_h", shape=[self.embed_size, self.hidden_size], initializer=self.initializer)
            self.U_h = tf.get_variable("U_h", shape=[self.hidden_size, self.hidden_size], initializer=self.initializer)
            self.b_h = tf.get_variable("b_h", shape=[self.hidden_size])
with tf.name_scope("gru_weights_sentence_level"):
self.W_z_sentence = tf.get_variable("W_z_sentence", shape=[self.hidden_size * 2, self.hidden_size * 2],
initializer=self.initializer)
self.U_z_sentence = tf.get_variable("U_z_sentence", shape=[self.hidden_size * 2, self.hidden_size * 2],
initializer=self.initializer)
self.b_z_sentence = tf.get_variable("b_z_sentence", shape=[self.hidden_size * 2])
# GRU parameters:reset gate related
self.W_r_sentence = tf.get_variable("W_r_sentence", shape=[self.hidden_size * 2, self.hidden_size * 2],
initializer=self.initializer)
self.U_r_sentence = tf.get_variable("U_r_sentence", shape=[self.hidden_size * 2, self.hidden_size * 2],
initializer=self.initializer)
self.b_r_sentence = tf.get_variable("b_r_sentence", shape=[self.hidden_size * 2])
self.W_h_sentence = tf.get_variable("W_h_sentence", shape=[self.hidden_size * 2, self.hidden_size * 2],
initializer=self.initializer)
self.U_h_sentence = tf.get_variable("U_h_sentence", shape=[self.hidden_size * 2, self.hidden_size * 2],
initializer=self.initializer)
self.b_h_sentence = tf.get_variable("b_h_sentence", shape=[self.hidden_size * 2])
with tf.name_scope("attention"):
self.W_w_attention_word = tf.get_variable("W_w_attention_word",
shape=[self.hidden_size * 2, self.hidden_size * 2],
initializer=self.initializer)
self.W_b_attention_word = tf.get_variable("W_b_attention_word", shape=[self.hidden_size * 2])
self.W_w_attention_sentence = tf.get_variable("W_w_attention_sentence",
shape=[self.hidden_size * 4, self.hidden_size * 2],
initializer=self.initializer)
self.W_b_attention_sentence = tf.get_variable("W_b_attention_sentence", shape=[self.hidden_size * 2])
            self.context_vecotor_word = tf.get_variable("what_is_the_informative_word", shape=[self.hidden_size * 2],
                                                        initializer=self.initializer)  # TODO: OK to use batch_size in the first dimension?
self.context_vecotor_sentence = tf.get_variable("what_is_the_informative_sentence",
shape=[self.hidden_size * 2], initializer=self.initializer)
def gru_single_step_word_level(self, Xt, h_t_minus_1):
"""
single step of gru for word level
        :param Xt: [batch_size*num_sentences,embed_size]
        :param h_t_minus_1: [batch_size*num_sentences,hidden_size]
        :return: h_t:[batch_size*num_sentences,hidden_size]
"""
        # update gate: decides how much past information is kept and how much new information is added.
        z_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_z) + tf.matmul(h_t_minus_1, self.U_z)
                            + self.b_z)  # z_t:[batch_size*num_sentences,hidden_size]
        # reset gate: controls how much the past state contributes to the candidate state.
        r_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_r) + tf.matmul(h_t_minus_1, self.U_r)
                            + self.b_r)  # r_t:[batch_size*num_sentences,hidden_size]
        # candidate state h_t~
        h_t_candiate = tf.nn.tanh(tf.matmul(Xt, self.W_h) + r_t * tf.matmul(h_t_minus_1, self.U_h)
                                  + self.b_h)  # h_t_candiate:[batch_size*num_sentences,hidden_size]
        # new state: a linear interpolation between the previous hidden state and the candidate state h_t~
        h_t = (1 - z_t) * h_t_minus_1 + z_t * h_t_candiate  # h_t:[batch_size*num_sentences,hidden_size]
        return h_t
def gru_single_step_sentence_level(self, Xt,
h_t_minus_1): # Xt:[batch_size, hidden_size*2]; h_t:[batch_size, hidden_size*2]
"""
single step of gru for sentence level
:param Xt:[batch_size, hidden_size*2]
:param h_t_minus_1:[batch_size, hidden_size*2]
        :return: h_t:[batch_size,hidden_size*2]
"""
        # update gate: decides how much past information is kept and how much new information is added.
        z_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_z_sentence) + tf.matmul(h_t_minus_1, self.U_z_sentence)
                            + self.b_z_sentence)  # z_t:[batch_size,hidden_size*2]
        # reset gate: controls how much the past state contributes to the candidate state.
        r_t = tf.nn.sigmoid(tf.matmul(Xt, self.W_r_sentence) + tf.matmul(h_t_minus_1, self.U_r_sentence)
                            + self.b_r_sentence)  # r_t:[batch_size,hidden_size*2]
        # candidate state h_t~
        h_t_candiate = tf.nn.tanh(tf.matmul(Xt, self.W_h_sentence) + r_t * tf.matmul(h_t_minus_1, self.U_h_sentence)
                                  + self.b_h_sentence)  # h_t_candiate:[batch_size,hidden_size*2]
        # new state: a linear interpolation between the previous hidden state and the candidate state h_t~
        h_t = (1 - z_t) * h_t_minus_1 + z_t * h_t_candiate  # h_t:[batch_size,hidden_size*2]
        return h_t
    # forward gru for first level: word level
def gru_forward_word_level(self, embedded_words):
"""
:param embedded_words:[batch_size*num_sentences,sentence_length,embed_size]
:return:forward hidden state: a list.length is sentence_length, each element is [batch_size*num_sentences,hidden_size]
"""
# split embedded_words
embedded_words_splitted = tf.split(embedded_words, self.sequence_length,
axis=1) # it is a list,length is sentence_length, each element is [batch_size*num_sentences,1,embed_size]
embedded_words_squeeze = [tf.squeeze(x, axis=1) for x in
embedded_words_splitted] # it is a list,length is sentence_length, each element is [batch_size*num_sentences,embed_size]
        # TODO: infer the batch dimension dynamically (e.g. from embedded_words_squeeze[0]) instead of relying on self.batch_size
        h_t = tf.ones((self.batch_size * self.num_sentences,
                       self.hidden_size))  # initial hidden state:[batch_size*num_sentences,hidden_size]
h_t_forward_list = []
for time_step, Xt in enumerate(embedded_words_squeeze): # Xt: [batch_size*num_sentences,embed_size]
            h_t = self.gru_single_step_word_level(Xt, h_t)  # Xt:[batch_size*num_sentences,embed_size]; h_t:[batch_size*num_sentences,hidden_size]
h_t_forward_list.append(h_t)
return h_t_forward_list # a list,length is sentence_length, each element is [batch_size*num_sentences,hidden_size]
# backward gru for first level: word level
def gru_backward_word_level(self, embedded_words):
"""
:param embedded_words:[batch_size*num_sentences,sentence_length,embed_size]
:return: backward hidden state:a list.length is sentence_length, each element is [batch_size*num_sentences,hidden_size]
"""
# split embedded_words
embedded_words_splitted = tf.split(embedded_words, self.sequence_length,
axis=1) # it is a list,length is sentence_length, each element is [batch_size*num_sentences,1,embed_size]
embedded_words_squeeze = [tf.squeeze(x, axis=1) for x in
embedded_words_splitted] # it is a list,length is sentence_length, each element is [batch_size*num_sentences,embed_size]
embedded_words_squeeze.reverse() # it is a list,length is sentence_length, each element is [batch_size*num_sentences,embed_size]
        # TODO: infer the batch dimension dynamically instead of relying on self.batch_size
        h_t = tf.ones((self.batch_size * self.num_sentences, self.hidden_size))  # initial hidden state:[batch_size*num_sentences,hidden_size]
h_t_backward_list = []
for time_step, Xt in enumerate(embedded_words_squeeze):
h_t = self.gru_single_step_word_level(Xt, h_t)
h_t_backward_list.append(h_t)
h_t_backward_list.reverse() #ADD 2017.06.14
return h_t_backward_list
# forward gru for second level: sentence level
def gru_forward_sentence_level(self, sentence_representation):
"""
:param sentence_representation: [batch_size,num_sentences,hidden_size*2]
        :return: forward hidden state: a list, length is num_sentences, each element is [batch_size,hidden_size*2]
"""
# split embedded_words
sentence_representation_splitted = tf.split(sentence_representation, self.num_sentences,
axis=1) # it is a list.length is num_sentences,each element is [batch_size,1,hidden_size*2]
sentence_representation_squeeze = [tf.squeeze(x, axis=1) for x in
sentence_representation_splitted] # it is a list.length is num_sentences,each element is [batch_size, hidden_size*2]
        # TODO: infer the batch dimension dynamically instead of relying on self.batch_size
        h_t = tf.ones((self.batch_size, self.hidden_size * 2))  # initial hidden state:[batch_size,hidden_size*2]
        h_t_forward_list = []
        for time_step, Xt in enumerate(sentence_representation_squeeze):  # Xt:[batch_size, hidden_size*2]
            h_t = self.gru_single_step_sentence_level(Xt, h_t)  # h_t:[batch_size,hidden_size*2]
            h_t_forward_list.append(h_t)
        return h_t_forward_list  # a list, length is num_sentences, each element is [batch_size,hidden_size*2]
# backward gru for second level: sentence level
def gru_backward_sentence_level(self, sentence_representation):
"""
:param sentence_representation: [batch_size,num_sentences,hidden_size*2]
        :return: backward hidden state: a list, length is num_sentences, each element is [batch_size,hidden_size*2]
"""
# split embedded_words
sentence_representation_splitted = tf.split(sentence_representation, self.num_sentences,
axis=1) # it is a list.length is num_sentences,each element is [batch_size,1,hidden_size*2]
sentence_representation_squeeze = [tf.squeeze(x, axis=1) for x in
sentence_representation_splitted] # it is a list.length is num_sentences,each element is [batch_size, hidden_size*2]
sentence_representation_squeeze.reverse()
        # TODO: infer the batch dimension dynamically instead of relying on self.batch_size
        h_t = tf.ones((self.batch_size, self.hidden_size * 2))  # initial hidden state:[batch_size,hidden_size*2]
        h_t_backward_list = []
        for time_step, Xt in enumerate(sentence_representation_squeeze):  # Xt:[batch_size, hidden_size*2]
            h_t = self.gru_single_step_sentence_level(Xt, h_t)  # h_t:[batch_size,hidden_size*2]
            h_t_backward_list.append(h_t)
        h_t_backward_list.reverse()  # ADD 2017.06.14
        return h_t_backward_list  # a list, length is num_sentences, each element is [batch_size,hidden_size*2]
def attention_word_level(self, hidden_state):
"""
input1:self.hidden_state: hidden_state:list,len:sentence_length,element:[batch_size*num_sentences,hidden_size*2]
        input2: word level context vector:[hidden_size*2]
:return:representation.shape:[batch_size*num_sentences,hidden_size*2]
"""
hidden_state_ = tf.stack(hidden_state, axis=1) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
# 0) one layer of feed forward network
hidden_state_2 = tf.reshape(hidden_state_, shape=[-1,
self.hidden_size * 2]) # shape:[batch_size*num_sentences*sequence_length,hidden_size*2]
        # hidden_state_2:[batch_size*num_sentences*sequence_length,hidden_size*2]; W_w_attention_word:[hidden_size*2,hidden_size*2]
hidden_representation = tf.nn.tanh(tf.matmul(hidden_state_2,
self.W_w_attention_word) + self.W_b_attention_word) # shape:[batch_size*num_sentences*sequence_length,hidden_size*2]
hidden_representation = tf.reshape(hidden_representation, shape=[-1, self.sequence_length,
self.hidden_size * 2]) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
        # attention process: 1. get logits for each word in the sentence; 2. get a probability distribution over the words; 3. take the weighted sum of the word states as the sentence representation.
        # 1) get logits for each word in the sentence.
hidden_state_context_similiarity = tf.multiply(hidden_representation,
self.context_vecotor_word) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
attention_logits = tf.reduce_sum(hidden_state_context_similiarity,
axis=2) # shape:[batch_size*num_sentences,sequence_length]
# subtract max for numerical stability (softmax is shift invariant). tf.reduce_max:Computes the maximum of elements across dimensions of a tensor.
attention_logits_max = tf.reduce_max(attention_logits, axis=1,
keep_dims=True) # shape:[batch_size*num_sentences,1]
        # 2) get a probability distribution over the words in the sentence.
p_attention = tf.nn.softmax(
attention_logits - attention_logits_max) # shape:[batch_size*num_sentences,sequence_length]
# 3) get weighted hidden state by attention vector
p_attention_expanded = tf.expand_dims(p_attention, axis=2) # shape:[batch_size*num_sentences,sequence_length,1]
# below sentence_representation'shape:[batch_size*num_sentences,sequence_length,hidden_size*2]<----p_attention_expanded:[batch_size*num_sentences,sequence_length,1];hidden_state_:[batch_size*num_sentences,sequence_length,hidden_size*2]
sentence_representation = tf.multiply(p_attention_expanded,
hidden_state_) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
sentence_representation = tf.reduce_sum(sentence_representation,
axis=1) # shape:[batch_size*num_sentences,hidden_size*2]
return sentence_representation # shape:[batch_size*num_sentences,hidden_size*2]
def attention_sentence_level(self, hidden_state_sentence):
"""
input1: hidden_state_sentence: a list,len:num_sentence,element:[None,hidden_size*4]
input2: sentence level context vector:[self.hidden_size*2]
:return:representation.shape:[None,hidden_size*4]
"""
hidden_state_ = tf.stack(hidden_state_sentence, axis=1) # shape:[None,num_sentence,hidden_size*4]
# 0) one layer of feed forward
hidden_state_2 = tf.reshape(hidden_state_,
shape=[-1, self.hidden_size * 4]) # [None*num_sentence,hidden_size*4]
hidden_representation = tf.nn.tanh(tf.matmul(hidden_state_2,
self.W_w_attention_sentence) + self.W_b_attention_sentence) # shape:[None*num_sentence,hidden_size*2]
hidden_representation = tf.reshape(hidden_representation, shape=[-1, self.num_sentences,
self.hidden_size * 2]) # [None,num_sentence,hidden_size*2]
        # attention process: 1. get logits for each sentence in the doc; 2. get a probability distribution over the sentences; 3. take the weighted sum of the sentence states as the doc representation.
        # 1) get logits for each sentence in the doc.
hidden_state_context_similiarity = tf.multiply(hidden_representation,
self.context_vecotor_sentence) # shape:[None,num_sentence,hidden_size*2]
attention_logits = tf.reduce_sum(hidden_state_context_similiarity,
axis=2) # shape:[None,num_sentence]. that is get logit for each num_sentence.
# subtract max for numerical stability (softmax is shift invariant). tf.reduce_max:computes the maximum of elements across dimensions of a tensor.
attention_logits_max = tf.reduce_max(attention_logits, axis=1, keep_dims=True) # shape:[None,1]
        # 2) get a probability distribution over the sentences in the doc.
p_attention = tf.nn.softmax(attention_logits - attention_logits_max) # shape:[None,num_sentence]
# 3) get weighted hidden state by attention vector(sentence level)
p_attention_expanded = tf.expand_dims(p_attention, axis=2) # shape:[None,num_sentence,1]
        sentence_representation = tf.multiply(p_attention_expanded,
                                              hidden_state_)  # shape:[None,num_sentence,hidden_size*4]<---p_attention_expanded:[None,num_sentence,1];hidden_state_:[None,num_sentence,hidden_size*4]
        sentence_representation = tf.reduce_sum(sentence_representation, axis=1)  # shape:[None,hidden_size*4]
        return sentence_representation  # document representation, shape:[None,hidden_size*4]
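

if __name__ == "__main__":
    # Minimal smoke-test sketch: build the graph with small, arbitrary hyperparameters and run one
    # forward pass on random data. Assumes a TF 1.x runtime, that sequence_length is divisible by
    # num_sentences, and that the constructor's batch_size matches the batch that is fed.
    batch_size = 8
    han = TextHAN(model_type='clf', sequence_length=40, num_sentences=4, num_classes=3,
                  vocab_size=1000, embedding_size=64, hidden_size=64, batch_size=batch_size,
                  l2_reg_lambda=0.5)
    x = np.random.randint(0, 1000, size=(batch_size, 40))
    y = np.eye(3)[np.random.randint(0, 3, size=batch_size)].astype(np.float32)  # random one-hot labels
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        loss, acc = sess.run([han.loss, han.accuracy],
                             feed_dict={han.input_x: x, han.input_y: y, han.dropout_keep_prob: 1.0})
        print("loss: %.4f, accuracy: %.4f" % (loss, acc))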