-
Notifications
You must be signed in to change notification settings - Fork 148
/
Copy pathtrain.py
204 lines (186 loc) · 10.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import argparse
import functools
import os
import time
from datetime import timedelta
import paddle
import yaml
from loguru import logger
from paddle.io import DataLoader
from tqdm import tqdm
from visualdl import LogWriter
from data_utils.collate_fn import collate_fn
from data_utils.audio_featurizer import AudioFeaturizer
from data_utils.tokenizer import Tokenizer
from data_utils.reader import CustomDataset
from data_utils.sampler import SortagradBatchSampler
from decoders.ctc_greedy_search import ctc_greedy_search_batch
from model_utils.model import DeepSpeech2Model
from utils.checkpoint import load_checkpoint, load_pretrained, save_checkpoint
from utils.metrics import wer, cer
from utils.scheduler import WarmupLR
from utils.summary import summary
from utils.utils import add_arguments, print_arguments, dict_to_object
# Command-line options. They are parsed at import time into the module-level
# ``args``, which train() and evaluate() read directly.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# Device / batching / schedule
add_arg('use_gpu', bool, True, "是否使用GPU训练")
add_arg('batch_size', int, 16, "训练每一批数据的大小")
add_arg('num_epoch', int, 200, "训练的轮数")
# Model architecture
add_arg('num_conv_layers', int, 2, "卷积层数量")
add_arg('num_rnn_layers', int, 3, "循环神经网络的数量")
add_arg('rnn_layer_size', int, 1024, "循环神经网络的大小")
# Optimisation and data loading
add_arg('learning_rate', float, 5e-4, "初始学习率")
add_arg('num_workers', int, 8, "读取数据的线程数量")
add_arg('min_duration', float, 0.5, "最短的用于训练的音频长度")
add_arg('max_duration', float, 20.0, "最长的用于训练的音频长度")
# Checkpoints: resume_model restores model + optimizer state, while
# pretrained_model only loads weights.
add_arg('resume_model', str, None, "恢复训练的模型路径,当为None则不恢复训练")
add_arg('pretrained_model', str, None, "使用预训练模型的路径,当为None时不使用预训练模型")
# Data files and output locations
add_arg('train_manifest', str, 'dataset/manifest.train', "训练的数据列表")
add_arg('test_manifest', str, 'dataset/manifest.test', "测试的数据列表")
add_arg('mean_istd_path', str, 'dataset/mean_istd.json', "均值和标准差的json文件路径,后缀 (.json)")
add_arg('vocab_dir', str, 'dataset/vocab_model', "生成的数据字典模型文件夹")
add_arg('output_model_dir', str, 'models/', "保存训练模型的文件夹")
add_arg('augment_conf_path', str, 'configs/augmentation.yml', "数据增强的配置文件,为yaml格式")
# Evaluation metric: character error rate (cer) or word error rate (wer)
add_arg('metrics_type', str, 'cer', "评估所使用的错误率方法,有字错率(cer)、词错率(wer)", choices=['wer', 'cer'])
args = parser.parse_args()
print_arguments(args=args)
def train():
    """Train the DeepSpeech2 model using the options in the module-level ``args``.

    Builds the train/test data loaders, the model, the warm-up LR scheduler and
    the Adam optimizer. It then optionally loads pretrained weights or resumes
    from a checkpoint, and runs the epoch loop:

    * every batch: CTC loss, backward pass, optimizer + scheduler step, and
      logging of loss / learning rate to VisualDL;
    * every 100 batches: log the running loss, learning rate and ETA;
    * every epoch: evaluate on the test set, save a checkpoint, and save a
      "best" checkpoint whenever the test error rate improves.

    Raises:
        RuntimeError: if ``args.use_gpu`` is set but Paddle was not compiled
            with CUDA.
    """
    # Select the device.
    if args.use_gpu:
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        if not paddle.is_compiled_with_cuda():
            raise RuntimeError('GPU不可用')
        paddle.device.set_device("gpu")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        paddle.device.set_device("cpu")

    # Read the data-augmentation configuration (YAML).
    # NOTE(review): FullLoader is fine for a trusted local config file; switch
    # to yaml.safe_load if this path can ever come from an untrusted source.
    with open(args.augment_conf_path, 'r', encoding='utf-8') as f:
        data_augment_configs = yaml.load(f.read(), Loader=yaml.FullLoader)
    print_arguments(configs=data_augment_configs, title='数据增强配置')
    data_augment_configs = dict_to_object(data_augment_configs)

    # Training data.
    audio_featurizer = AudioFeaturizer(mode="train")
    tokenizer = Tokenizer(args.vocab_dir)
    train_dataset = CustomDataset(data_manifest=args.train_manifest,
                                  audio_featurizer=audio_featurizer,
                                  tokenizer=tokenizer,
                                  min_duration=args.min_duration,
                                  max_duration=args.max_duration,
                                  aug_conf=data_augment_configs,
                                  mode="train")
    train_batch_sampler = SortagradBatchSampler(train_dataset,
                                                batch_size=args.batch_size,
                                                sortagrad=True,
                                                drop_last=True,
                                                shuffle=True)
    train_loader = DataLoader(dataset=train_dataset,
                              collate_fn=collate_fn,
                              batch_sampler=train_batch_sampler,
                              num_workers=args.num_workers)
    # Test data.
    test_dataset = CustomDataset(data_manifest=args.test_manifest,
                                 audio_featurizer=audio_featurizer,
                                 tokenizer=tokenizer,
                                 min_duration=args.min_duration,
                                 max_duration=args.max_duration,
                                 aug_conf=data_augment_configs,
                                 mode="eval")
    test_loader = DataLoader(dataset=test_dataset,
                             collate_fn=collate_fn,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

    # Model, plus a summary printed from a dummy input of batch 1, 100 frames.
    model = DeepSpeech2Model(input_dim=train_dataset.feature_dim,
                             vocab_size=train_dataset.vocab_size,
                             mean_istd_path=args.mean_istd_path,
                             num_conv_layers=args.num_conv_layers,
                             num_rnn_layers=args.num_rnn_layers,
                             rnn_layer_size=args.rnn_layer_size)
    input_data = [paddle.rand((1, 100, train_dataset.feature_dim)), paddle.to_tensor([100], dtype='int64')]
    summary(model, inputs=input_data)

    # Optimizer with warm-up LR schedule, L2 weight decay and global-norm clipping.
    scheduler = WarmupLR(learning_rate=args.learning_rate)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                      learning_rate=scheduler,
                                      weight_decay=5e-4,
                                      grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=400.0))

    # Optional warm start: pretrained weights and/or a full training checkpoint.
    if args.pretrained_model:
        model = load_pretrained(args.pretrained_model, model)
    last_epoch, error_rate = 0, 1
    if args.resume_model:
        model, optimizer, last_epoch, error_rate = load_checkpoint(args.resume_model, model, optimizer)
    model.train()

    ctc_loss = paddle.nn.CTCLoss(blank=tokenizer.blank_id, reduction='sum')
    steps_per_epoch = len(train_loader)
    # Keep the global step counter running across a resume, so the resumed
    # run's VisualDL scalars do not land on the previous run's step indices.
    # The remaining-step count (and therefore the ETA) is unchanged.
    train_step = last_epoch * steps_per_epoch
    max_step = args.num_epoch * steps_per_epoch
    # Start from the checkpoint's error rate (1 when not resuming). This stops
    # the first post-resume epoch from overwriting the saved best model with a
    # worse one.
    # NOTE(review): assumes load_checkpoint returns a numeric error rate — confirm.
    best_error_result = error_rate

    writer = LogWriter(logdir='log')
    try:
        for epoch_id in range(last_epoch, args.num_epoch):
            train_times, loss_sum = [], []
            start_epoch = time.time()
            start = time.time()
            for batch_id, batch in enumerate(train_loader()):
                inputs, labels, input_lens, label_lens = batch
                output, output_lens = model(inputs, input_lens)
                loss = ctc_loss(output, labels, output_lens, label_lens)
                # The loss is summed (reduction='sum'); divide by output.shape[1].
                # paddle.nn.CTCLoss takes logits as [T, B, C], so this should be
                # the batch size.
                loss = loss / output.shape[1]
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
                scheduler.step()

                loss_value = float(loss)
                loss_sum.append(loss_value)
                # Wall-clock time of this batch (data loading + step), in ms.
                train_times.append((time.time() - start) * 1000)
                # Log learning rate and loss.
                writer.add_scalar('Train/lr', scheduler.get_lr(), train_step)
                writer.add_scalar('Train/Loss', loss_value, train_step)
                train_step += 1

                # Print progress every 100 batches.
                if batch_id % 100 == 0:
                    # Remaining time, from the mean batch time of this window.
                    train_eta_sec = (sum(train_times) / len(train_times)) * (max_step - train_step) / 1000
                    eta_str = str(timedelta(seconds=int(train_eta_sec)))
                    train_loss = sum(loss_sum) / len(loss_sum)
                    logger.info(f'Train epoch: [{epoch_id + 1}/{args.num_epoch}], '
                                f'batch: [{batch_id}/{len(train_loader)}], '
                                f'loss: {train_loss:.5f}, '
                                f'learning_rate: {scheduler.get_lr():>.8f}, '
                                f'eta: {eta_str}')
                    train_times, loss_sum = [], []
                start = time.time()

            # Time spent on this epoch.
            train_time_str = str(timedelta(seconds=int(time.time() - start_epoch)))
            # Evaluate on the test set.
            error_result = evaluate(model, test_loader, tokenizer)
            writer.add_scalar(f'Test/{args.metrics_type}', error_result, epoch_id)
            logger.info(f'Test epoch: {epoch_id + 1},训练耗时:{train_time_str}, {args.metrics_type}: {error_result}')
            # Save the per-epoch checkpoint.
            save_checkpoint(model, optimizer, epoch_id + 1, save_model_path=args.output_model_dir,
                            error_rate=error_result, metrics_type=args.metrics_type)
            # Save the best checkpoint so far.
            if error_result < best_error_result:
                best_error_result = error_result
                save_checkpoint(model, optimizer, epoch_id + 1, save_model_path=args.output_model_dir,
                                error_rate=error_result, metrics_type=args.metrics_type, best_model=True)
    finally:
        # Flush and release the VisualDL log files even if training fails.
        writer.close()
def evaluate(model, test_loader, tokenizer, metrics_type=None):
    """Return the mean per-utterance error rate of ``model`` on ``test_loader``.

    Each batch is decoded with CTC greedy search. Hypotheses and references are
    converted to text with ``tokenizer`` and every pair is scored with WER or CER.

    Args:
        model: object providing ``eval()``, ``train()`` and
            ``predict(inputs, input_lens)``; the latter returns CTC
            probabilities and lengths as tensors with ``.numpy()``.
        test_loader: callable loader yielding
            ``(inputs, labels, input_lens, label_lens)`` batches.
        tokenizer: provides ``blank_id`` and ``ids2text``.
        metrics_type: ``'wer'`` for word error rate; anything else uses
            character error rate. Defaults to ``args.metrics_type``.

    Returns:
        The mean error rate as a float, or ``-1`` if no utterances were scored.

    The model is switched to eval mode for the pass and always switched back
    to train mode afterwards, even if evaluation raises.
    """
    if metrics_type is None:
        metrics_type = args.metrics_type
    score = wer if metrics_type == 'wer' else cer

    error_results = []
    model.eval()
    try:
        with paddle.no_grad():
            for batch in tqdm(test_loader()):
                inputs, labels, input_lens, _ = batch
                ctc_probs, ctc_lens = model.predict(inputs, input_lens)
                ctc_probs, ctc_lens = ctc_probs.numpy(), ctc_lens.numpy()
                out_tokens = ctc_greedy_search_batch(ctc_probs=ctc_probs, ctc_lens=ctc_lens,
                                                     blank_id=tokenizer.blank_id)
                out_strings = tokenizer.ids2text(list(out_tokens))
                # Remove the -1 entries from every label; presumably the padding
                # added by collate_fn — confirm.
                labels = [[token for token in label if token != -1] for label in labels.numpy().tolist()]
                labels_str = tokenizer.ids2text(labels)
                error_results.extend(score(label, out_string)
                                     for out_string, label in zip(out_strings, labels_str))
    finally:
        model.train()
    return float(sum(error_results) / len(error_results)) if error_results else -1
# Script entry point: start training only when run directly, not when imported.
if __name__ == "__main__":
    train()