@@ -221,7 +221,7 @@ def __init__(self, hparams):
221
221
[hparams .prenet_dim , hparams .prenet_dim ])
222
222
223
223
self .attention_rnn = nn .LSTMCell (
224
- hparams .prenet_dim + hparams .encoder_embedding_dim ,
224
+ hparams .decoder_rnn_dim + hparams .encoder_embedding_dim ,
225
225
hparams .attention_rnn_dim )
226
226
227
227
self .attention_layer = Attention (
@@ -230,7 +230,7 @@ def __init__(self, hparams):
230
230
hparams .attention_location_kernel_size )
231
231
232
232
self .decoder_rnn = nn .LSTMCell (
233
- hparams .attention_rnn_dim + hparams .encoder_embedding_dim ,
233
+ hparams .prenet_dim + hparams .encoder_embedding_dim ,
234
234
hparams .decoder_rnn_dim , 1 )
235
235
236
236
self .linear_projection = LinearNorm (
@@ -351,8 +351,7 @@ def decode(self, decoder_input):
351
351
attention_weights:
352
352
"""
353
353
354
- decoder_input = self .prenet (decoder_input )
355
- cell_input = torch .cat ((decoder_input , self .attention_context ), - 1 )
354
+ cell_input = torch .cat ((self .decoder_hidden , self .attention_context ), - 1 )
356
355
self .attention_hidden , self .attention_cell = self .attention_rnn (
357
356
cell_input , (self .attention_hidden , self .attention_cell ))
358
357
@@ -364,8 +363,8 @@ def decode(self, decoder_input):
364
363
attention_weights_cat , self .mask )
365
364
366
365
self .attention_weights_cum += self .attention_weights
367
- decoder_input = torch . cat (
368
- ( self . attention_hidden , self .attention_context ), - 1 )
366
+ prenet_output = self . prenet ( decoder_input )
367
+ decoder_input = torch . cat (( prenet_output , self .attention_context ), - 1 )
369
368
self .decoder_hidden , self .decoder_cell = self .decoder_rnn (
370
369
decoder_input , (self .decoder_hidden , self .decoder_cell ))
371
370
0 commit comments