Skip to content

Commit 7e54df4

Browse files
bobqianicggerganov
andauthored
whisper : significantly improve the inference quality (ggerganov#1148)
* Fix MSVC compile error C3688 Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC. * Significantly improve inference quality In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference. * Significantly improve inference quality At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue. * Addressed a few minor issues Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`. * Significantly improve inference quality Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information. * Add annotation and performance improvement * Calculate FFT only when fft_in are not all zero * Some minor performance improvement * Fixed a bug impacting inference quality * The first version after all the analysis is completed. * Fix some bugs and add debug mode * Fixed several bugs * Temporarily disable speed-up mode and add debug mode. * Add debug mode * Disable speed-up mode and add debug mode * Fix CI error (ggerganov#1) * Fix error * Fix error * Fixed several bugs including [BLANK_AUDIO] problem * Remove Hard-coded hann window * Some Final Fix (ggerganov#2) * Fix error * Fix error * Probably the last commit * Probably the last commit * whisper : minor coding style changes * whisper : remove debug from public API --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 20a8097 commit 7e54df4

File tree

3 files changed

+121
-84
lines changed

3 files changed

+121
-84
lines changed

examples/main/main.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct whisper_params {
7070
float logprob_thold = -1.00f;
7171

7272
bool speed_up = false;
73+
bool debug_mode = false;
7374
bool translate = false;
7475
bool detect_language = false;
7576
bool diarize = false;
@@ -135,7 +136,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
135136
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
136137
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
137138
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
138-
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
139+
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
140+
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
139141
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
140142
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
141143
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
@@ -190,7 +192,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
190192
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
191193
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
192194
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
193-
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
195+
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
196+
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
194197
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
195198
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
196199
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
@@ -915,6 +918,7 @@ int main(int argc, char ** argv) {
915918
wparams.split_on_word = params.split_on_word;
916919

917920
wparams.speed_up = params.speed_up;
921+
wparams.debug_mode = params.debug_mode;
918922

919923
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
920924

whisper.cpp

+114-82
Original file line numberDiff line numberDiff line change
@@ -2445,40 +2445,50 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
24452445
}
24462446
}
24472447

2448-
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
2449-
int n_samples, int fft_size, int fft_step, int n_threads,
2450-
const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
2451-
std::vector<float> fft_in(fft_size, 0.0);
2452-
std::vector<float> fft_out(2 * fft_size);
2453-
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
2454-
2455-
for (int i = ith; i < mel.n_len; i += n_threads) {
2456-
const int offset = i * fft_step;
2457-
2458-
// apply Hanning window
2459-
for (int j = 0; j < fft_size; j++) {
2460-
if (offset + j < n_samples) {
2461-
fft_in[j] = hann[j] * samples[offset + j];
2462-
} else {
2463-
fft_in[j] = 0.0;
2464-
}
2465-
}
2448+
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2449+
if (output.size() < length) {
2450+
output.resize(length);
2451+
}
2452+
int offset = -1;
2453+
if (periodic) {
2454+
offset = 0;
2455+
}
2456+
for (int i = 0; i < length; i++) {
2457+
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
2458+
}
24662459

2467-
// FFT -> mag^2
2468-
fft(fft_in, fft_out);
2460+
return true;
2461+
}
24692462

2470-
for (int j = 0; j < fft_size; j++) {
2471-
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2463+
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
2464+
int n_samples, int frame_size, int frame_step, int n_threads,
2465+
const whisper_filters & filters, whisper_mel & mel) {
2466+
std::vector<float> fft_in(frame_size, 0.0);
2467+
std::vector<float> fft_out(2 * frame_step);
2468+
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2469+
int n_fft = 1 + (frame_size / 2);
2470+
int i = ith;
2471+
2472+
// calculate FFT only when fft_in are not all zero
2473+
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2474+
const int offset = i * frame_step;
2475+
2476+
// apply Hanning window (~10% faster)
2477+
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
2478+
fft_in[j] = hann[j] * samples[offset + j];
24722479
}
2473-
for (int j = 1; j < fft_size / 2; j++) {
2474-
fft_out[j] += fft_out[fft_size - j];
2480+
// fill the rest with zeros
2481+
if (n_samples - offset < frame_size) {
2482+
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
24752483
}
24762484

2477-
if (speed_up) {
2478-
// scale down in the frequency domain results in a speed up in the time domain
2479-
for (int j = 0; j < n_fft; j++) {
2480-
fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
2481-
}
2485+
// FFT
2486+
fft(fft_in, fft_out);
2487+
2488+
// Calculate modulus^2 of complex numbers
2489+
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
2490+
for (int j = 0; j < frame_size; j++) {
2491+
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
24822492
}
24832493

24842494
// mel spectrogram
@@ -2489,10 +2499,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
24892499
int k = 0;
24902500
for (k = 0; k < n_fft - 3; k += 4) {
24912501
sum +=
2492-
fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
2493-
fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
2494-
fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
2495-
fft_out[k + 3] * filters.data[j*n_fft + k + 3];
2502+
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
2503+
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
2504+
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
2505+
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
24962506
}
24972507

24982508
// handle n_fft remainder
@@ -2505,68 +2515,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
25052515
mel.data[j * mel.n_len + i] = sum;
25062516
}
25072517
}
2518+
2519+
// Otherwise fft_out are all zero
2520+
double sum = log10(1e-10);
2521+
for (; i < mel.n_len; i += n_threads) {
2522+
for (int j = 0; j < mel.n_mel; j++) {
2523+
mel.data[j * mel.n_len + i] = sum;
2524+
}
2525+
}
25082526
}
25092527

2510-
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2528+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
25112529
static bool log_mel_spectrogram(
2512-
whisper_state & wstate,
2513-
const float * samples,
2530+
whisper_state & wstate,
2531+
const float * samples,
25142532
const int n_samples,
25152533
const int /*sample_rate*/,
2516-
const int fft_size,
2517-
const int fft_step,
2534+
const int frame_size,
2535+
const int frame_step,
25182536
const int n_mel,
25192537
const int n_threads,
2520-
const whisper_filters & filters,
2521-
const bool speed_up,
2522-
whisper_mel & mel) {
2538+
const whisper_filters & filters,
2539+
const bool debug,
2540+
whisper_mel & mel) {
25232541
const int64_t t_start_us = ggml_time_us();
25242542

2525-
// Hanning window
2543+
// Hanning window (Use cosf to eliminate difference)
2544+
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2545+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
25262546
std::vector<float> hann;
2527-
hann.resize(fft_size);
2528-
for (int i = 0; i < fft_size; i++) {
2529-
hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
2530-
}
2531-
2532-
mel.n_mel = n_mel;
2533-
mel.n_len = n_samples/fft_step;
2534-
mel.n_len_org = mel.n_len;
2547+
hann_window(frame_size, true, hann);
25352548

2536-
std::vector<float> samples_padded;
25372549

2538-
// pad audio with at least one extra chunk of zeros
2539-
{
2540-
const int pad = (100*WHISPER_CHUNK_SIZE)/2;
2550+
// Calculate the length of padding
2551+
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
2552+
int64_t stage_2_pad = frame_size / 2;
25412553

2542-
if (mel.n_len % pad != 0) {
2543-
mel.n_len = (mel.n_len/pad + 1)*pad;
2544-
}
2545-
mel.n_len += pad;
2554+
// Initialize a vector and copy data from C array to it.
2555+
std::vector<float> samples_padded;
2556+
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
2557+
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
25462558

2547-
samples_padded.resize(mel.n_len*fft_step);
2548-
memcpy(samples_padded.data(), samples, n_samples*sizeof(float));
2549-
memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
2559+
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
2560+
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
25502561

2551-
samples = samples_padded.data();
2552-
}
2562+
// reflective pad 200 samples at the beginning of audio
2563+
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
25532564

2554-
mel.data.resize(mel.n_mel*mel.n_len);
2565+
mel.n_mel = n_mel;
2566+
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
2567+
// Calculate number of frames + remove the last frame
2568+
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
2569+
// Calculate semi-padded sample length to ensure compatibility
2570+
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
2571+
mel.data.resize(mel.n_mel * mel.n_len);
25552572

2556-
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2557-
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
25582573

25592574
{
25602575
std::vector<std::thread> workers(n_threads - 1);
25612576
for (int iw = 0; iw < n_threads - 1; ++iw) {
25622577
workers[iw] = std::thread(
2563-
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
2564-
n_samples, fft_size, fft_step, n_threads,
2565-
std::cref(filters), speed_up, std::ref(mel));
2578+
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
2579+
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
2580+
std::cref(filters), std::ref(mel));
25662581
}
25672582

25682583
// main thread
2569-
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
2584+
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
25702585

25712586
for (int iw = 0; iw < n_threads - 1; ++iw) {
25722587
workers[iw].join();
@@ -2580,7 +2595,6 @@ static bool log_mel_spectrogram(
25802595
mmax = mel.data[i];
25812596
}
25822597
}
2583-
//printf("%s: max = %f\n", __func__, mmax);
25842598

25852599
mmax -= 8.0;
25862600

@@ -2594,7 +2608,16 @@ static bool log_mel_spectrogram(
25942608

25952609
wstate.t_mel_us += ggml_time_us() - t_start_us;
25962610

2597-
//printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
2611+
// Dump log_mel_spectrogram
2612+
if (debug) {
2613+
std::ofstream outFile("log_mel_spectrogram.json");
2614+
outFile << "[";
2615+
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
2616+
outFile << mel.data[i] << ", ";
2617+
}
2618+
outFile << mel.data[mel.data.size() - 1] << "]";
2619+
outFile.close();
2620+
}
25982621

25992622
return true;
26002623
}
@@ -3026,21 +3049,30 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
30263049
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
30273050
}
30283051

3029-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
3052+
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
30303053
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3031-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
3054+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
30323055
log("%s: failed to compute mel spectrogram\n", __func__);
30333056
return -1;
30343057
}
30353058

30363059
return 0;
30373060
}
30383061

3039-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
3062+
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
30403063
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
30413064
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
30423065
}
30433066

3067+
// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
3068+
// TODO
3069+
3070+
// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
3071+
// TODO
3072+
3073+
// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
3074+
// TODO
3075+
30443076
int whisper_set_mel_with_state(
30453077
struct whisper_context * /*ctx*/,
30463078
struct whisper_state * state,
@@ -3492,6 +3524,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
34923524
/*.max_tokens =*/ 0,
34933525

34943526
/*.speed_up =*/ false,
3527+
/*.debug_mode =*/ false,
34953528
/*.audio_ctx =*/ 0,
34963529

34973530
/*.tdrz_enable =*/ false,
@@ -3653,7 +3686,7 @@ static void whisper_process_logits(
36533686
WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
36543687

36553688
// extract the logits for the last token
3656-
// we will be mutating and therefore we don't want to use the ctx.logits buffer directly
3689+
// we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
36573690
auto & probs = decoder.probs;
36583691
auto & logits = decoder.logits;
36593692
auto & logprobs = decoder.logprobs;
@@ -4056,10 +4089,9 @@ int whisper_full_with_state(
40564089

40574090
// compute log mel spectrogram
40584091
if (params.speed_up) {
4059-
if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
4060-
log("%s: failed to compute log mel spectrogram\n", __func__);
4061-
return -1;
4062-
}
4092+
// TODO: Replace PV with more advanced algorithm
4093+
log("%s: failed to compute log mel spectrogram\n", __func__);
4094+
return -1;
40634095
} else {
40644096
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
40654097
log("%s: failed to compute log mel spectrogram\n", __func__);
@@ -4095,8 +4127,8 @@ int whisper_full_with_state(
40954127
const int seek_start = params.offset_ms/10;
40964128
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
40974129

4098-
// if length of spectrogram is less than 1s (100 samples), then return
4099-
// basically don't process anything that is less than 1s
4130+
// if length of spectrogram is less than 1.0s (100 frames), then return
4131+
// basically don't process anything that is less than 1.0s
41004132
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
41014133
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
41024134
return 0;

whisper.h

+1
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ extern "C" {
375375
// [EXPERIMENTAL] speed-up techniques
376376
// note: these can significantly reduce the quality of the output
377377
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
378+
bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel)
378379
int audio_ctx; // overwrite the audio context size (0 = use default)
379380

380381
// [EXPERIMENTAL] [TDRZ] tinydiarize

0 commit comments

Comments
 (0)