forked from Plachtaa/VITS-fast-fine-tuning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwhisper_transcribe.py
90 lines (81 loc) · 3.34 KB
/
whisper_transcribe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import whisper
import os
import torchaudio
import json
# Map a Whisper language code to the wrapper token that tags transcripts
# for the text cleaner, e.g. 'zh' -> "[ZH]".
lang2token = {code: f"[{code.upper()}]" for code in ("zh", "ja", "en")}
def transcribe_one(audio_path):
    """Transcribe a single audio file with the module-level Whisper ``model``.

    Pads/trims the audio to Whisper's 30-second window, detects the spoken
    language, then decodes the transcript.

    Args:
        audio_path: path to an audio file readable by ``whisper.load_audio``.

    Returns:
        (lang, text): the detected language code (e.g. ``'zh'``) and the
        recognized text.

    NOTE: relies on the global ``model`` created in the ``__main__`` block.
    """
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)
    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # detect the spoken language — pick the most probable code exactly once
    # (the original computed max(probs, ...) twice, once just for the print)
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")
    # decode the audio
    options = whisper.DecodingOptions()
    result = whisper.decode(model, mel, options)
    # print the recognized text
    print(result.text)
    return lang, result.text
if __name__ == "__main__":
model = whisper.load_model("medium")
parent_dir = "./custom_character_voice/"
speaker_names = list(os.walk(parent_dir))[0][1]
speaker2id = {}
speaker_annos = []
# resample audios
for speaker in speaker_names:
speaker2id[speaker] = 1000 + len(speaker2id)
for i, wavfile in enumerate(list(os.walk(parent_dir + speaker))[0][2]):
# try to load file as audio
if wavfile.startswith("processed_"):
continue
try:
wav, sr = torchaudio.load(parent_dir + speaker + "/" + wavfile, frame_offset=0, num_frames=-1, normalize=True,
channels_first=True)
wav = wav.mean(dim=0).unsqueeze(0)
if sr != 22050:
wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=22050)(wav)
if wav.shape[1] / sr > 20:
print(f"{wavfile} too long, ignoring\n")
save_path = parent_dir + speaker + "/" + f"processed_{i}.wav"
torchaudio.save(save_path, wav, 22050, channels_first=True)
# transcribe text
lang, text = transcribe_one(save_path)
if lang not in ['zh', 'en', 'ja']:
print(f"{lang} not supported, ignoring\n")
text = lang2token[lang] + text + lang2token[lang] + "\n"
speaker_annos.append(save_path + "|" + str(speaker2id[speaker]) + "|" + text)
except:
continue
# clean annotation
import text
cleaned_speaker_annos = []
for i, line in enumerate(speaker_annos):
path, sid, txt = line.split("|")
if len(txt) > 100:
continue
cleaned_text = text._clean_text(txt, ["cjke_cleaners2"])
cleaned_text += "\n" if not cleaned_text.endswith("\n") else ""
cleaned_speaker_annos.append(path + "|" + sid + "|" + cleaned_text)
with open("custom_character_anno.txt", 'w', encoding='utf-8') as f:
for line in cleaned_speaker_annos:
f.write(line)
# generate new config
with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f:
hps = json.load(f)
# modify n_speakers
hps['data']["n_speakers"] = 1000 + len(speaker2id)
# add speaker names
for speaker in speaker_names:
hps['speakers'][speaker] = speaker2id[speaker]
# save modified config
with open("./configs/modified_finetune_speaker.json", 'w', encoding='utf-8') as f:
json.dump(hps, f, indent=2)
print("finished")