create_data.py
import argparse
import functools
import json
import os
from collections import Counter
from tqdm import tqdm
from yeaudio.audio import AudioSegment
from zhconv import convert
from data_utils.normalizer import FeatureNormalizer
from data_utils.tokenizer import Tokenizer
from utils.utils import add_arguments, print_arguments, read_manifest

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('annotation_path', str, 'dataset/annotation/', 'Path to the annotation files; if the directory contains test.txt, all of its data is used as test data')
add_arg('manifest_prefix', str, 'dataset/', 'Output prefix for the data manifests, which list audio paths and transcripts')
add_arg('max_test_manifest', int, 10000, 'Maximum number of test samples')
add_arg('count_threshold', int, 2, 'Truncation threshold for character counts; 0 means no limit')
add_arg('vocab_dir', str, 'dataset/vocab_model', 'Output directory for the generated vocabulary model')
add_arg('vocab_model_type', str, 'unigram', 'Vocabulary model type; use char for character-based languages such as Chinese, unigram for others')
add_arg('vocab_size', int, 5000, 'Size of the generated vocabulary; ignored when vocab_model_type is char')
add_arg('manifest_path', str, 'dataset/manifest.train', 'Path to the data manifest')
add_arg('num_samples', int, 1000000, 'Number of audio samples used to compute the mean and standard deviation; -1 means use all data')
add_arg('mean_istd_filepath', str, 'dataset/mean_istd.json', 'Path to the JSON file holding the mean and inverse standard deviation (suffix .json)')
args = parser.parse_args()
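
# Example invocation (a sketch; the paths below are just the defaults declared above):
#   python create_data.py --annotation_path=dataset/annotation/ --manifest_prefix=dataset/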


# Create the data manifests
def create_manifest(annotation_path, manifest_path_prefix):
    data_list = []
    test_list = []
    durations_all = []
    duration_0_10 = 0
    duration_10_20 = 0
    duration_20 = 0
    # Walk through every annotation file
    for annotation_text in os.listdir(annotation_path):
        durations = []
        print('Creating the manifest for %s, please wait ...' % annotation_text)
        annotation_text_path = os.path.join(annotation_path, annotation_text)
        # Read the annotation file
        with open(annotation_text_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        for line in tqdm(lines):
            audio_path = line.split('\t')[0]
            try:
                # Filter out non-text characters
                text = is_ustr(line.split('\t')[1].replace('\n', '').replace('\r', ''))
                # Make sure the transcript is Simplified Chinese
                text = convert(text, 'zh-cn')
                # Get the audio duration
                audio = AudioSegment.from_file(audio_path)
                duration = audio.duration
                if duration <= 10:
                    duration_0_10 += 1
                elif 10 < duration <= 20:
                    duration_10_20 += 1
                else:
                    duration_20 += 1
                durations.append(duration)
                d = json.dumps(
                    {
                        'audio_filepath': audio_path.replace('\\', '/'),
                        'duration': duration,
                        'text': text
                    },
                    ensure_ascii=False)
                # Everything in test.txt becomes test data
                if annotation_text == 'test.txt':
                    test_list.append(d)
                else:
                    data_list.append(d)
            except Exception as e:
                print(e)
                continue
        durations_all.append(sum(durations))
        print('%s contains [%d] hours of data!' % (annotation_text, int(sum(durations) / 3600)))
    print('Clips of 0-10s: %d, 10-20s: %d, over 20s: %d' % (duration_0_10, duration_10_20, duration_20))
    # Write the audio path, duration and transcript into the manifests
    f_train = open(os.path.join(manifest_path_prefix, 'manifest.train'), 'w', encoding='utf-8')
    f_test = open(os.path.join(manifest_path_prefix, 'manifest.test'), 'w', encoding='utf-8')
    for line in test_list:
        f_test.write(line + '\n')
    # Without a test.txt, hold out every interval-th sample as test data
    interval = 500
    if len(data_list) / interval > args.max_test_manifest:
        interval = len(data_list) // args.max_test_manifest
    for i, line in enumerate(data_list):
        if i % interval == 0 and i != 0:
            if len(test_list) == 0:
                f_test.write(line + '\n')
            else:
                f_train.write(line + '\n')
        else:
            f_train.write(line + '\n')
    f_train.close()
    f_test.close()
    print('Finished creating the manifests; the full dataset contains [%d] hours!' % int(sum(durations_all) / 3600))


# Replace non-text characters with spaces, then strip all whitespace
def is_ustr(in_str):
    out_str = ''
    for i in range(len(in_str)):
        if is_uchar(in_str[i]):
            out_str = out_str + in_str[i]
        else:
            out_str = out_str + ' '
    return ''.join(out_str.split())


# Keep only CJK characters; digits, Latin letters and punctuation are all rejected
def is_uchar(uchar):
    if u'\u4e00' <= uchar <= u'\u9fa5':
        return True
    if u'\u0030' <= uchar <= u'\u0039':
        return False
    if (u'\u0041' <= uchar <= u'\u005a') or (u'\u0061' <= uchar <= u'\u007a'):
        return False
    if uchar in ('-', ',', '.', '>', '?'):
        return False
    return False
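
# For example (an illustrative call, not from the original source):
#   is_ustr('Hello, 世界 123!')  ->  '世界'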


# Collect character counts from a manifest
def count_manifest(counter, manifest_path):
    manifest_jsons = read_manifest(manifest_path)
    for line_json in manifest_jsons:
        for char in line_json['text']:
            counter.update(char)
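
# Sketch of how count_manifest pairs with the imported Counter (hypothetical usage;
# main() below builds the vocabulary via Tokenizer instead):
#   counter = Counter()
#   count_manifest(counter, args.manifest_path)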


# Compute the dataset's mean and standard deviation
def compute_mean_std(manifest_path, num_samples, mean_istd_filepath):
    # Randomly take the given number of samples to compute the normalization statistics
    normalizer = FeatureNormalizer(mean_istd_filepath=mean_istd_filepath)
    # Save the computed statistics to file
    normalizer.compute_mean_istd(manifest_path=manifest_path, num_samples=num_samples)
    print(f'The computed mean and standard deviation have been saved to {mean_istd_filepath}!')
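
# Note: judging by the filename, mean_istd.json is assumed to hold per-feature mean and
# inverse-standard-deviation values; see FeatureNormalizer for the exact format it writes.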


def main():
    print_arguments(args)
    print('Creating the data manifests...')
    create_manifest(annotation_path=args.annotation_path,
                    manifest_path_prefix=args.manifest_prefix)
    print('Building the vocabulary...')
    tokenizer = Tokenizer(vocab_model_dir=args.vocab_dir,
                          model_type=args.vocab_model_type,
                          build_vocab_size=args.vocab_size,
                          is_build_vocab=True)
    tokenizer.build_vocab(manifest_paths=[args.manifest_path])
    print('The vocabulary has been built and saved to: %s' % args.vocab_dir)
    print('=' * 70)
    print('Sampling %s clips to compute the mean and standard deviation...' % args.num_samples)
    compute_mean_std(args.manifest_path, args.num_samples, args.mean_istd_filepath)
    print('=' * 70)


if __name__ == '__main__':
    main()