Commit 09bbec0
Rafael Valle committed May 3, 2018

    adding python files

1 parent 25a267d
15 files changed, +2134 −0 lines

audio_processing.py (+93)

import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util


def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    win_sq = librosa_util.pad_center(win_sq, n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    PARAMS
    ------
    magnitudes: spectrogram magnitudes
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    """
    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = angles.astype(np.float32)
    angles = torch.autograd.Variable(torch.from_numpy(angles))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for i in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
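
A minimal usage sketch for these helpers (illustrative values, not part of the commit; griffin_lim additionally needs the STFT class from stft.py below):

import torch
from audio_processing import (window_sumsquare, dynamic_range_compression,
                              dynamic_range_decompression)

# Hann sum-square envelope for 10 frames at the default hop/window sizes;
# output length is n_fft + hop_length * (n_frames - 1) = 2600 samples
env = window_sumsquare('hann', n_frames=10)

# log-compression and its inverse round-trip (exact above the 1e-5 clip floor)
mag = torch.rand(80, 100) + 1e-2
restored = dynamic_range_decompression(dynamic_range_compression(mag))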

data_utils.py (+101)

import random
import torch
import torch.utils.data

import layers
from utils import load_wav_to_torch, load_filepaths_and_text
from text import text_to_sequence


class TextMelLoader(torch.utils.data.Dataset):
    """
    1) loads audio, text pairs
    2) normalizes text and converts it to a sequence of symbol IDs
    3) computes mel-spectrograms from audio files.
    """
    def __init__(self, audiopaths_and_text, hparams, shuffle=True):
        self.audiopaths_and_text = load_filepaths_and_text(
            audiopaths_and_text, hparams.sort_by_length)
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.stft = layers.TacotronSTFT(
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
            hparams.mel_fmax)
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audiopaths_and_text)

    def get_mel_text_pair(self, audiopath_and_text):
        # separate filename and text
        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
        text = self.get_text(text)
        mel = self.get_mel(audiopath)
        return (text, mel)

    def get_mel(self, filename):
        audio = load_wav_to_torch(filename, self.sampling_rate)
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        return melspec

    def get_text(self, text):
        text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
        return text_norm

    def __getitem__(self, index):
        return self.get_mel_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)


class TextMelCollate():
    """Zero-pads model inputs and targets based on number of frames per step
    """
    def __init__(self, n_frames_per_step):
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collates a training batch from normalized text and mel-spectrograms
        PARAMS
        ------
        batch: [text_normalized, mel_normalized]
        """
        # Right zero-pad all text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]][0]
            text_padded[i, :text.size(0)] = text

        # Right zero-pad mel-spec with extra single zero vector to mark the end
        num_mels = batch[0][1].size(0)
        max_target_len = max([x[1].size(1) for x in batch]) + 1
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
            assert max_target_len % self.n_frames_per_step == 0

        # include mel padded and gate padded
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        gate_padded = torch.FloatTensor(len(batch), max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        for i in range(len(ids_sorted_decreasing)):
            mel = batch[ids_sorted_decreasing[i]][1]
            mel_padded[i, :, :mel.size(1)] = mel
            gate_padded[i, mel.size(1):] = 1
            output_lengths[i] = mel.size(1)

        return text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths
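
A minimal loading sketch (not part of the commit), assuming hparams.training_files points at an existing "path|text" filelist and the text package is importable:

from torch.utils.data import DataLoader
from hparams import create_hparams
from data_utils import TextMelLoader, TextMelCollate

hparams = create_hparams()
trainset = TextMelLoader(hparams.training_files, hparams)
collate_fn = TextMelCollate(hparams.n_frames_per_step)
loader = DataLoader(trainset, batch_size=hparams.batch_size,
                    shuffle=False, collate_fn=collate_fn)
text_padded, input_lengths, mel_padded, gate_padded, output_lengths = \
    next(iter(loader))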

distributed.py (+120)

import torch
import torch.distributed as dist
from torch.nn.modules import Module


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.
    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.
    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten.
    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
    return flat


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.
    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.
    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)


'''
This version of DistributedDataParallel is designed to be used in conjunction
with the multiproc.py launcher included with this example. It assumes that your
run uses multiprocessing with one GPU per process, that the model is on the
correct device, and that torch.cuda.set_device has been used to set the device.

Parameters are broadcast to the other processes on initialization of
DistributedDataParallel, and will be allreduced at the finish of the backward
pass.
'''
class DistributedDataParallel(Module):

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        # fallback for PyTorch 0.3
        if not hasattr(dist, '_backend'):
            self.warn_on_half = True
        else:
            self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

        self.module = module

        for p in self.module.state_dict().values():
            if not torch.is_tensor(p):
                continue
            dist.broadcast(p, 0)

        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                if self.warn_on_half:
                    if torch.cuda.HalfTensor in buckets:
                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                              " It is recommended to use the NCCL backend in this case." +
                              " This currently requires PyTorch built from top of tree master.")
                        self.warn_on_half = False

                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                param._execution_engine.queue_callback(allreduce_params)
            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def forward(self, *inputs, **kwargs):
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

    '''
    def _sync_buffers(self):
        buffers = list(self.module._all_buffers())
        if len(buffers) > 0:
            # cross-node buffer sync
            flat_buffers = _flatten_dense_tensors(buffers)
            dist.broadcast(flat_buffers, 0)
            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
                buf.copy_(synced)

    def train(self, mode=True):
        # Clear the NCCL communicator and CUDA event cache of the default
        # group ID. These caches will be recreated on a later call. This is
        # currently a workaround for a potential NCCL deadlock.
        if dist._backend == dist.dist_backend.NCCL:
            dist._clear_group_cache()
        super(DistributedDataParallel, self).train(mode)
        self.module.train(mode)
    '''
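
A sketch of the intended usage, one process per GPU as launched by multiproc.py (the --rank/--n_gpus flags and the toy Linear module are illustrative stand-ins, not repository code):

import argparse
import torch
import torch.distributed as dist
from distributed import DistributedDataParallel

parser = argparse.ArgumentParser()
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--n_gpus', type=int, default=1)
args, _ = parser.parse_known_args()

# every process must see the same init file for rendezvous
torch.cuda.set_device(args.rank)
dist.init_process_group(backend='nccl', init_method='file://distributed.dpt',
                        world_size=args.n_gpus, rank=args.rank)

model = DistributedDataParallel(torch.nn.Linear(10, 10).cuda())
out = model(torch.randn(4, 10).cuda())   # forward arms needs_reduction
out.sum().backward()                     # hooks average grads across ranks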

fp16_optimizer.py (+381)

(Large diff not rendered by default.)

hparams.py (+91)

import tensorflow as tf
from text import symbols


def create_hparams(hparams_string=None, verbose=False):
    """Create model hyperparameters. Parse non-default values from the given string."""

    hparams = tf.contrib.training.HParams(
        ################################
        # Experiment Parameters        #
        ################################
        epochs=500,
        iters_per_checkpoint=500,
        seed=1234,
        dynamic_loss_scaling=True,
        fp16_run=False,
        distributed_run=False,
        dist_backend="nccl",
        dist_url="file://distributed.dpt",
        cudnn_enabled=True,
        cudnn_benchmark=False,

        ################################
        # Data Parameters              #
        ################################
        training_files='ljs_audio_text_train_filelist.txt',
        validation_files='ljs_audio_text_val_filelist.txt',
        text_cleaners=['english_cleaners'],
        sort_by_length=False,

        ################################
        # Audio Parameters             #
        ################################
        max_wav_value=32768.0,
        sampling_rate=22050,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,  # if None, half the sampling rate

        ################################
        # Model Parameters             #
        ################################
        n_symbols=len(symbols),
        symbols_embedding_dim=512,

        # Encoder parameters
        encoder_kernel_size=5,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,

        # Decoder parameters
        n_frames_per_step=1,
        decoder_rnn_dim=1024,
        prenet_dim=256,
        max_decoder_steps=1000,
        gate_threshold=0.6,

        # Attention parameters
        attention_rnn_dim=1024,
        attention_dim=128,

        # Location Layer parameters
        attention_location_n_filters=32,
        attention_location_kernel_size=31,

        # Mel-post processing network parameters
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,

        ################################
        # Optimization Hyperparameters #
        ################################
        learning_rate=1e-3,
        weight_decay=1e-6,
        grad_clip_thresh=1,
        batch_size=48,
        mask_padding=False  # set model's padded outputs to padded values
    )

    if hparams_string:
        tf.logging.info('Parsing command line hparams: %s', hparams_string)
        hparams.parse(hparams_string)

    if verbose:
        tf.logging.info('Final parsed hparams: %s', hparams.values())

    return hparams
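
Overriding defaults via a comma-separated "name=value" string, as train.py does with its --hparams flag (a sketch; requires TensorFlow 1.x for tf.contrib):

from hparams import create_hparams

hparams = create_hparams("batch_size=32,learning_rate=5e-4")
print(hparams.batch_size, hparams.learning_rate)   # 32 0.0005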

layers.py (+80)

import torch
from librosa.filters import mel as librosa_mel_fn
from audio_processing import dynamic_range_compression
from audio_processing import dynamic_range_decompression
from stft import STFT


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class TacotronSTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=None):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output
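
A quick shape check on random audio in the documented [-1, 1] range (a sketch, not repository code):

import torch
from layers import TacotronSTFT

stft = TacotronSTFT()
wav = torch.rand(2, 22050) * 2 - 1   # two 1-second clips in [-1, 1]
mel = stft.mel_spectrogram(wav)
print(mel.shape)                     # (2, 80, 87) frames at hop_length=256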

logger.py (+48)

import random
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy
from plotting_utils import plot_gate_outputs_to_numpy


class Tacotron2Logger(SummaryWriter):
    def __init__(self, logdir):
        super(Tacotron2Logger, self).__init__(logdir)

    def log_training(self, reduced_loss, grad_norm, learning_rate, duration,
                     iteration):
        self.add_scalar("training.loss", reduced_loss, iteration)
        self.add_scalar("grad.norm", grad_norm, iteration)
        self.add_scalar("learning.rate", learning_rate, iteration)
        self.add_scalar("duration", duration, iteration)

    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)
        self.add_image(
            "alignment",
            plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
            iteration)
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "gate",
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                F.sigmoid(gate_outputs[idx]).data.cpu().numpy()),
            iteration)

loss_function.py (+19)

from torch import nn


class Tacotron2Loss(nn.Module):
    def __init__(self):
        super(Tacotron2Loss, self).__init__()

    def forward(self, model_output, targets):
        mel_target, gate_target = targets[0], targets[1]
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        gate_target = gate_target.view(-1, 1)

        mel_out, mel_out_postnet, gate_out, _ = model_output
        gate_out = gate_out.view(-1, 1)
        mel_loss = nn.MSELoss()(mel_out, mel_target) + \
            nn.MSELoss()(mel_out_postnet, mel_target)
        gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
        return mel_loss + gate_loss
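
A sanity check with dummy tensors shaped like TextMelCollate's outputs (a sketch; the None stands in for the alignments, which the loss ignores):

import torch
from loss_function import Tacotron2Loss

B, n_mels, T = 4, 80, 100
y_pred = (torch.randn(B, n_mels, T),   # mel_out
          torch.randn(B, n_mels, T),   # mel_out_postnet
          torch.randn(B, T),           # gate_out (logits)
          None)                        # alignments (unused by the loss)
targets = (torch.randn(B, n_mels, T), torch.zeros(B, T))
print(Tacotron2Loss()(y_pred, targets))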

loss_scaler.py (+132)

import torch


class LossScaler:

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        pass

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()


class DynamicLossScaler:

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        for p in params:
            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
                return True

        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        inf_count = torch.sum(x.abs() == float('inf'))
        if inf_count > 0:
            return True
        nan_count = torch.sum(x != x)
        return nan_count > 0

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        if overflow:
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()


##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
if __name__ == "__main__":
    import torch
    from torch.autograd import Variable
    # (if this lived in a separate file: from loss_scaler import DynamicLossScaler)

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
    x = Variable(torch.randn(N, D_in), requires_grad=False)
    y = Variable(torch.randn(N, D_out), requires_grad=False)

    w1 = Variable(torch.randn(D_in, H), requires_grad=True)
    w2 = Variable(torch.randn(H, D_out), requires_grad=True)
    parameters = [w1, w2]

    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
        print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow (called on the instance, not the class)
        has_overflow = loss_scaler.has_overflow(parameters)

        # If no overflow, unscale grads and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't do anything -- i.e., skip the iteration
        else:
            print('OVERFLOW!')

        # Update loss scale for next iteration
        loss_scaler.update_scale(has_overflow)
model.py (+541)

(Large diff not rendered by default.)

multiproc.py (+23)

import time
import torch
import sys
import subprocess

argslist = list(sys.argv)[1:]
num_gpus = torch.cuda.device_count()
argslist.append('--n_gpus={}'.format(num_gpus))
workers = []
job_id = time.strftime("%Y_%m_%d-%H%M%S")
argslist.append("--group_name=group_{}".format(job_id))

for i in range(num_gpus):
    argslist.append('--rank={}'.format(i))
    stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i),
                                      "w")
    print(argslist)
    p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout)
    workers.append(p)
    argslist = argslist[:-1]

for p in workers:
    p.wait()
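
Typical invocation, following the launcher logic above: everything after multiproc.py is forwarded verbatim to one worker per visible GPU, with --n_gpus, --group_name, and --rank appended. Rank 0 prints to the console; the others log under logs/, which must already exist. The flags shown are illustrative:

python multiproc.py train.py --output_directory=outdir --log_directory=logdir --hparams=distributed_run=True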

plotting_utils.py (+61)

import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np


def save_figure_to_numpy(fig):
    # save it to a numpy array.
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_alignment_to_numpy(alignment, info=None):
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment, aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def plot_spectrogram_to_numpy(spectrogram):
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5,
               color='green', marker='+', s=1, label='target')
    ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5,
               color='red', marker='.', s=1, label='predicted')

    plt.xlabel("Frames (Green target, Red predicted)")
    plt.ylabel("Gate State")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data

stft.py (+140)

"""
BSD 3-Clause License

Copyright (c) 2017, Prem Seetharaman
All rights reserved.

* Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice,
  this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from this
  software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from audio_processing import window_sumsquare


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            # the window must fit inside the FFT frame before zero-padding
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
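
An analysis/synthesis round trip; with the sum-square window correction the output should closely match the input away from the clip edges (a sketch, not repository code):

import torch
from stft import STFT

stft = STFT(filter_length=800, hop_length=200, win_length=800)
audio = torch.rand(1, 16000) * 2 - 1
magnitude, phase = stft.transform(audio)
reconstruction = stft.inverse(magnitude, phase).squeeze(1)
print(audio.shape, reconstruction.shape)   # both (1, 16000)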

train.py (+272)

import os
import time
import argparse
import math

import torch
from distributed import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.nn import DataParallel
from torch.utils.data import DataLoader

from fp16_optimizer import FP16_Optimizer

from model import Tacotron2
from data_utils import TextMelLoader, TextMelCollate
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams


def batchnorm_to_float(module):
    """Converts batch norm modules to FP32"""
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        batchnorm_to_float(child)
    return module


def reduce_tensor(tensor, num_gpus):
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM)
    rt /= num_gpus
    return rt


def init_distributed(hparams, n_gpus, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed")
    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    torch.distributed.init_process_group(
        backend=hparams.dist_backend, init_method=hparams.dist_url,
        world_size=n_gpus, rank=rank, group_name=group_name)

    print("Done initializing distributed")


def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    train_sampler = DistributedSampler(trainset) \
        if hparams.distributed_run else None

    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn


def prepare_directories_and_logger(output_directory, log_directory, rank):
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
    else:
        logger = None
    return logger


def load_model(hparams):
    model = Tacotron2(hparams).cuda()
    model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
    model = DistributedDataParallel(model) \
        if hparams.distributed_run else DataParallel(model)
    return model


def warm_start_model(checkpoint_path, model):
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint_dict['state_dict'])
    return model


def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y = model.module.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
                if distributed_run else loss.data[0]
            val_loss += reduced_val_loss
        val_loss = val_loss / (i + 1)

    model.train()
    return val_loss


def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)
    if hparams.fp16_run:
        optimizer = FP16_Optimizer(
            optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)

    criterion = Tacotron2Loss()

    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)

    train_loader, valset, collate_fn = prepare_dataloaders(hparams)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model = warm_start_model(checkpoint_path, model)
        else:
            model, optimizer, learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            model.zero_grad()
            x, y = model.module.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
                if hparams.distributed_run else loss.data[0]

            if hparams.fp16_run:
                optimizer.backward(loss)
                grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm(
                    model.module.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            overflow = optimizer.overflow if hparams.fp16_run else False

            if not overflow and not math.isnan(reduced_loss) and rank == 0:
                duration = time.perf_counter() - start
                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                    iteration, reduced_loss, grad_norm, duration))

                logger.log_training(
                    reduced_loss, grad_norm, learning_rate, duration, iteration)

            if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
                reduced_val_loss = validate(
                    model, criterion, valset, iteration, hparams.batch_size,
                    n_gpus, collate_fn, logger, hparams.distributed_run, rank)

                if rank == 0:
                    print("Validation loss {}: {:9f}".format(
                        iteration, reduced_val_loss))
                    logger.log_validation(
                        reduced_val_loss, model, y, y_pred, iteration)
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output_directory', type=str,
                        help='directory to save checkpoints')
    parser.add_argument('-l', '--log_directory', type=str,
                        help='directory to save tensorboard logs')
    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
                        required=False, help='checkpoint path')
    parser.add_argument('--warm_start', action='store_true',
                        help='load the model only (warm start)')
    parser.add_argument('--n_gpus', type=int, default=1,
                        required=False, help='number of gpus')
    parser.add_argument('--rank', type=int, default=0,
                        required=False, help='rank of current gpu')
    parser.add_argument('--group_name', type=str, default='group_name',
                        required=False, help='Distributed group name')
    parser.add_argument('--hparams', type=str,
                        required=False, help='comma separated name=value pairs')

    args = parser.parse_args()
    hparams = create_hparams(args.hparams)

    torch.backends.cudnn.enabled = hparams.cudnn_enabled
    torch.backends.cudnn.benchmark = hparams.cudnn_benchmark

    print("FP16 Run:", hparams.fp16_run)
    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
    print("Distributed Run:", hparams.distributed_run)
    print("cuDNN Enabled:", hparams.cudnn_enabled)
    print("cuDNN Benchmark:", hparams.cudnn_benchmark)

    train(args.output_directory, args.log_directory, args.checkpoint_path,
          args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)
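
Typical single-GPU launch, matching the argparse flags above (directory names are illustrative):

python train.py --output_directory=outdir --log_directory=logdir --hparams=batch_size=32

For multi-GPU runs, launch the same command through multiproc.py with distributed_run=True in --hparams.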

utils.py (+32)

import numpy as np
from scipy.io.wavfile import read
import torch


def get_mask_from_lengths(lengths):
    max_len = torch.max(lengths)
    ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda()
    mask = (ids < lengths.unsqueeze(1)).byte()
    return mask


def load_wav_to_torch(full_path, sr):
    sampling_rate, data = read(full_path)
    assert sr == sampling_rate, "{} SR doesn't match {} on path {}".format(
        sr, sampling_rate, full_path)
    return torch.FloatTensor(data.astype(np.float32))


def load_filepaths_and_text(filename, sort_by_length, split="|"):
    with open(filename, encoding='utf-8') as f:
        filepaths_and_text = [line.strip().split(split) for line in f]

    if sort_by_length:
        filepaths_and_text.sort(key=lambda x: len(x[1]))

    return filepaths_and_text


def to_gpu(x):
    # `async` became a reserved word in Python 3.7; PyTorch >= 0.4 accepts
    # the equivalent `non_blocking` keyword instead.
    x = x.contiguous().cuda(non_blocking=True)
    return torch.autograd.Variable(x)
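
A toy batch for get_mask_from_lengths (a sketch against the PyTorch 0.4-era API used throughout this commit; requires CUDA, since the function moves its index tensor to the GPU). mask[i, j] is 1 exactly where j < lengths[i]:

import torch
from utils import get_mask_from_lengths

lengths = torch.LongTensor([3, 1, 2]).cuda()
mask = get_mask_from_lengths(lengths)   # shape (3, 3), dtype uint8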
