Skip to content

Commit 064629c

Browse files
authored
Merge pull request #23 from NVIDIA/attention_full_mel
model.py: make the attention RNN consume the full mel frame instead of the prenet (dropout-regularized) mel output
2 parents da30fd8 + d5b6472 commit 064629c

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

model.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __init__(self, hparams):
221221
[hparams.prenet_dim, hparams.prenet_dim])
222222

223223
self.attention_rnn = nn.LSTMCell(
224-
hparams.prenet_dim + hparams.encoder_embedding_dim,
224+
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
225225
hparams.attention_rnn_dim)
226226

227227
self.attention_layer = Attention(
@@ -230,7 +230,7 @@ def __init__(self, hparams):
230230
hparams.attention_location_kernel_size)
231231

232232
self.decoder_rnn = nn.LSTMCell(
233-
hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
233+
hparams.prenet_dim + hparams.encoder_embedding_dim,
234234
hparams.decoder_rnn_dim, 1)
235235

236236
self.linear_projection = LinearNorm(
@@ -351,8 +351,7 @@ def decode(self, decoder_input):
351351
attention_weights:
352352
"""
353353

354-
decoder_input = self.prenet(decoder_input)
355-
cell_input = torch.cat((decoder_input, self.attention_context), -1)
354+
cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
356355
self.attention_hidden, self.attention_cell = self.attention_rnn(
357356
cell_input, (self.attention_hidden, self.attention_cell))
358357

@@ -364,8 +363,8 @@ def decode(self, decoder_input):
364363
attention_weights_cat, self.mask)
365364

366365
self.attention_weights_cum += self.attention_weights
367-
decoder_input = torch.cat(
368-
(self.attention_hidden, self.attention_context), -1)
366+
prenet_output = self.prenet(decoder_input)
367+
decoder_input = torch.cat((prenet_output, self.attention_context), -1)
369368
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
370369
decoder_input, (self.decoder_hidden, self.decoder_cell))
371370

0 commit comments

Comments
 (0)