@@ -2445,40 +2445,50 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2445
2445
}
2446
2446
}
2447
2447
2448
- static void log_mel_spectrogram_worker_thread (int ith, const std::vector<float > &hann, const float *samples,
2449
- int n_samples, int fft_size, int fft_step, int n_threads,
2450
- const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
2451
- std::vector<float > fft_in (fft_size, 0.0 );
2452
- std::vector<float > fft_out (2 * fft_size);
2453
- int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2 );
2454
-
2455
- for (int i = ith; i < mel.n_len ; i += n_threads) {
2456
- const int offset = i * fft_step;
2457
-
2458
- // apply Hanning window
2459
- for (int j = 0 ; j < fft_size; j++) {
2460
- if (offset + j < n_samples) {
2461
- fft_in[j] = hann[j] * samples[offset + j];
2462
- } else {
2463
- fft_in[j] = 0.0 ;
2464
- }
2465
- }
2448
+ static bool hann_window (int length, bool periodic, std::vector<float > & output) {
2449
+ if (output.size () < length) {
2450
+ output.resize (length);
2451
+ }
2452
+ int offset = -1 ;
2453
+ if (periodic) {
2454
+ offset = 0 ;
2455
+ }
2456
+ for (int i = 0 ; i < length; i++) {
2457
+ output[i] = 0.5 *(1.0 - cosf ((2.0 *M_PI*i)/(length + offset)));
2458
+ }
2466
2459
2467
- // FFT -> mag^2
2468
- fft (fft_in, fft_out);
2460
+ return true ;
2461
+ }
2469
2462
2470
- for (int j = 0 ; j < fft_size; j++) {
2471
- fft_out[j] = (fft_out[2 * j + 0 ] * fft_out[2 * j + 0 ] + fft_out[2 * j + 1 ] * fft_out[2 * j + 1 ]);
2463
+ static void log_mel_spectrogram_worker_thread (int ith, const std::vector<float > & hann, const std::vector<float > & samples,
2464
+ int n_samples, int frame_size, int frame_step, int n_threads,
2465
+ const whisper_filters & filters, whisper_mel & mel) {
2466
+ std::vector<float > fft_in (frame_size, 0.0 );
2467
+ std::vector<float > fft_out (2 * frame_step);
2468
+ // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2469
+ int n_fft = 1 + (frame_size / 2 );
2470
+ int i = ith;
2471
+
2472
+ // calculate FFT only when fft_in are not all zero
2473
+ for (; i < std::min (n_samples / frame_step + 1 , mel.n_len ); i += n_threads) {
2474
+ const int offset = i * frame_step;
2475
+
2476
+ // apply Hanning window (~10% faster)
2477
+ for (int j = 0 ; j < std::min (frame_size, n_samples - offset); j++) {
2478
+ fft_in[j] = hann[j] * samples[offset + j];
2472
2479
}
2473
- for (int j = 1 ; j < fft_size / 2 ; j++) {
2474
- fft_out[j] += fft_out[fft_size - j];
2480
+ // fill the rest with zeros
2481
+ if (n_samples - offset < frame_size) {
2482
+ std::fill (fft_in.begin () + (n_samples - offset), fft_in.end (), 0.0 );
2475
2483
}
2476
2484
2477
- if (speed_up) {
2478
- // scale down in the frequency domain results in a speed up in the time domain
2479
- for (int j = 0 ; j < n_fft; j++) {
2480
- fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1 ]);
2481
- }
2485
+ // FFT
2486
+ fft (fft_in, fft_out);
2487
+
2488
+ // Calculate modulus^2 of complex numbers
2489
+ // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
2490
+ for (int j = 0 ; j < frame_size; j++) {
2491
+ fft_out[j] = (fft_out[2 * j + 0 ] * fft_out[2 * j + 0 ] + fft_out[2 * j + 1 ] * fft_out[2 * j + 1 ]);
2482
2492
}
2483
2493
2484
2494
// mel spectrogram
@@ -2489,10 +2499,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2489
2499
int k = 0 ;
2490
2500
for (k = 0 ; k < n_fft - 3 ; k += 4 ) {
2491
2501
sum +=
2492
- fft_out[k + 0 ] * filters.data [j* n_fft + k + 0 ] +
2493
- fft_out[k + 1 ] * filters.data [j* n_fft + k + 1 ] +
2494
- fft_out[k + 2 ] * filters.data [j* n_fft + k + 2 ] +
2495
- fft_out[k + 3 ] * filters.data [j* n_fft + k + 3 ];
2502
+ fft_out[k + 0 ] * filters.data [j * n_fft + k + 0 ] +
2503
+ fft_out[k + 1 ] * filters.data [j * n_fft + k + 1 ] +
2504
+ fft_out[k + 2 ] * filters.data [j * n_fft + k + 2 ] +
2505
+ fft_out[k + 3 ] * filters.data [j * n_fft + k + 3 ];
2496
2506
}
2497
2507
2498
2508
// handle n_fft remainder
@@ -2505,68 +2515,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2505
2515
mel.data [j * mel.n_len + i] = sum;
2506
2516
}
2507
2517
}
2518
+
2519
+ // Otherwise fft_out are all zero
2520
+ double sum = log10 (1e-10 );
2521
+ for (; i < mel.n_len ; i += n_threads) {
2522
+ for (int j = 0 ; j < mel.n_mel ; j++) {
2523
+ mel.data [j * mel.n_len + i] = sum;
2524
+ }
2525
+ }
2508
2526
}
2509
2527
2510
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2528
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
2511
2529
static bool log_mel_spectrogram (
2512
- whisper_state & wstate,
2513
- const float * samples,
2530
+ whisper_state & wstate,
2531
+ const float * samples,
2514
2532
const int n_samples,
2515
2533
const int /* sample_rate*/ ,
2516
- const int fft_size ,
2517
- const int fft_step ,
2534
+ const int frame_size ,
2535
+ const int frame_step ,
2518
2536
const int n_mel,
2519
2537
const int n_threads,
2520
- const whisper_filters & filters,
2521
- const bool speed_up ,
2522
- whisper_mel & mel) {
2538
+ const whisper_filters & filters,
2539
+ const bool debug ,
2540
+ whisper_mel & mel) {
2523
2541
const int64_t t_start_us = ggml_time_us ();
2524
2542
2525
- // Hanning window
2543
+ // Hanning window (Use cosf to eliminate difference)
2544
+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2545
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2526
2546
std::vector<float > hann;
2527
- hann.resize (fft_size);
2528
- for (int i = 0 ; i < fft_size; i++) {
2529
- hann[i] = 0.5 *(1.0 - cos ((2.0 *M_PI*i)/(fft_size)));
2530
- }
2531
-
2532
- mel.n_mel = n_mel;
2533
- mel.n_len = n_samples/fft_step;
2534
- mel.n_len_org = mel.n_len ;
2547
+ hann_window (frame_size, true , hann);
2535
2548
2536
- std::vector<float > samples_padded;
2537
2549
2538
- // pad audio with at least one extra chunk of zeros
2539
- {
2540
- const int pad = ( 100 *WHISPER_CHUNK_SIZE)/ 2 ;
2550
+ // Calculate the length of padding
2551
+ int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30 ;
2552
+ int64_t stage_2_pad = frame_size / 2 ;
2541
2553
2542
- if (mel. n_len % pad != 0 ) {
2543
- mel. n_len = (mel. n_len /pad + 1 )*pad ;
2544
- }
2545
- mel. n_len += pad ;
2554
+ // Initialize a vector and copy data from C array to it.
2555
+ std::vector< float > samples_padded ;
2556
+ samples_padded. resize (n_samples + stage_1_pad + stage_2_pad * 2 );
2557
+ std::copy (samples, samples + n_samples, samples_padded. begin () + stage_2_pad) ;
2546
2558
2547
- samples_padded.resize (mel.n_len *fft_step);
2548
- memcpy (samples_padded.data (), samples, n_samples*sizeof (float ));
2549
- memset (samples_padded.data () + n_samples, 0 , (mel.n_len *fft_step - n_samples)*sizeof (float ));
2559
+ // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
2560
+ std::fill (samples_padded.begin () + n_samples + stage_2_pad, samples_padded.begin () + n_samples + stage_1_pad + 2 * stage_2_pad, 0 );
2550
2561
2551
- samples = samples_padded. data ();
2552
- }
2562
+ // reflective pad 200 samples at the beginning of audio
2563
+ std::reverse_copy (samples + 1 , samples + 1 + stage_2_pad, samples_padded. begin ());
2553
2564
2554
- mel.data .resize (mel.n_mel *mel.n_len );
2565
+ mel.n_mel = n_mel;
2566
+ // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
2567
+ // Calculate number of frames + remove the last frame
2568
+ mel.n_len = (samples_padded.size () - frame_size) / frame_step;
2569
+ // Calculate semi-padded sample length to ensure compatibility
2570
+ mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
2571
+ mel.data .resize (mel.n_mel * mel.n_len );
2555
2572
2556
- // printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2557
- // printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
2558
2573
2559
2574
{
2560
2575
std::vector<std::thread> workers (n_threads - 1 );
2561
2576
for (int iw = 0 ; iw < n_threads - 1 ; ++iw) {
2562
2577
workers[iw] = std::thread (
2563
- log_mel_spectrogram_worker_thread, iw + 1 , std::cref (hann), samples ,
2564
- n_samples, fft_size, fft_step , n_threads,
2565
- std::cref (filters), speed_up, std::ref (mel));
2578
+ log_mel_spectrogram_worker_thread, iw + 1 , std::cref (hann), samples_padded ,
2579
+ n_samples + stage_2_pad, frame_size, frame_step , n_threads,
2580
+ std::cref (filters), std::ref (mel));
2566
2581
}
2567
2582
2568
2583
// main thread
2569
- log_mel_spectrogram_worker_thread (0 , hann, samples , n_samples, fft_size, fft_step , n_threads, filters, speed_up , mel);
2584
+ log_mel_spectrogram_worker_thread (0 , hann, samples_padded , n_samples + stage_2_pad, frame_size, frame_step , n_threads, filters, mel);
2570
2585
2571
2586
for (int iw = 0 ; iw < n_threads - 1 ; ++iw) {
2572
2587
workers[iw].join ();
@@ -2580,7 +2595,6 @@ static bool log_mel_spectrogram(
2580
2595
mmax = mel.data [i];
2581
2596
}
2582
2597
}
2583
- // printf("%s: max = %f\n", __func__, mmax);
2584
2598
2585
2599
mmax -= 8.0 ;
2586
2600
@@ -2594,7 +2608,16 @@ static bool log_mel_spectrogram(
2594
2608
2595
2609
wstate.t_mel_us += ggml_time_us () - t_start_us;
2596
2610
2597
- // printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
2611
+ // Dump log_mel_spectrogram
2612
+ if (debug) {
2613
+ std::ofstream outFile (" log_mel_spectrogram.json" );
2614
+ outFile << " [" ;
2615
+ for (uint64_t i = 0 ; i < mel.data .size () - 1 ; i++) {
2616
+ outFile << mel.data [i] << " , " ;
2617
+ }
2618
+ outFile << mel.data [mel.data .size () - 1 ] << " ]" ;
2619
+ outFile.close ();
2620
+ }
2598
2621
2599
2622
return true ;
2600
2623
}
@@ -3026,21 +3049,30 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
3026
3049
return whisper_pcm_to_mel_with_state (ctx, ctx->state , samples, n_samples, n_threads);
3027
3050
}
3028
3051
3029
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
3052
+ // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3030
3053
int whisper_pcm_to_mel_phase_vocoder_with_state (struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3031
- if (!log_mel_spectrogram (*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model .filters , true , state->mel )) {
3054
+ if (!log_mel_spectrogram (*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model .filters , false , state->mel )) {
3032
3055
log (" %s: failed to compute mel spectrogram\n " , __func__);
3033
3056
return -1 ;
3034
3057
}
3035
3058
3036
3059
return 0 ;
3037
3060
}
3038
3061
3039
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
3062
+ // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3040
3063
int whisper_pcm_to_mel_phase_vocoder (struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
3041
3064
return whisper_pcm_to_mel_phase_vocoder_with_state (ctx, ctx->state , samples, n_samples, n_threads);
3042
3065
}
3043
3066
3067
+ // same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
3068
+ // TODO
3069
+
3070
+ // same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
3071
+ // TODO
3072
+
3073
+ // same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
3074
+ // TODO
3075
+
3044
3076
int whisper_set_mel_with_state (
3045
3077
struct whisper_context * /* ctx*/ ,
3046
3078
struct whisper_state * state,
@@ -3492,6 +3524,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3492
3524
/* .max_tokens =*/ 0 ,
3493
3525
3494
3526
/* .speed_up =*/ false ,
3527
+ /* .debug_mode =*/ false ,
3495
3528
/* .audio_ctx =*/ 0 ,
3496
3529
3497
3530
/* .tdrz_enable =*/ false ,
@@ -3653,7 +3686,7 @@ static void whisper_process_logits(
3653
3686
WHISPER_ASSERT (n_logits == ctx.vocab .n_vocab );
3654
3687
3655
3688
// extract the logits for the last token
3656
- // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
3689
+ // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
3657
3690
auto & probs = decoder.probs ;
3658
3691
auto & logits = decoder.logits ;
3659
3692
auto & logprobs = decoder.logprobs ;
@@ -4056,10 +4089,9 @@ int whisper_full_with_state(
4056
4089
4057
4090
// compute log mel spectrogram
4058
4091
if (params.speed_up ) {
4059
- if (whisper_pcm_to_mel_phase_vocoder_with_state (ctx, state, samples, n_samples, params.n_threads ) != 0 ) {
4060
- log (" %s: failed to compute log mel spectrogram\n " , __func__);
4061
- return -1 ;
4062
- }
4092
+ // TODO: Replace PV with more advanced algorithm
4093
+ log (" %s: failed to compute log mel spectrogram\n " , __func__);
4094
+ return -1 ;
4063
4095
} else {
4064
4096
if (whisper_pcm_to_mel_with_state (ctx, state, samples, n_samples, params.n_threads ) != 0 ) {
4065
4097
log (" %s: failed to compute log mel spectrogram\n " , __func__);
@@ -4095,8 +4127,8 @@ int whisper_full_with_state(
4095
4127
const int seek_start = params.offset_ms /10 ;
4096
4128
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state (state) : seek_start + params.duration_ms /10 ;
4097
4129
4098
- // if length of spectrogram is less than 1s (100 samples ), then return
4099
- // basically don't process anything that is less than 1s
4130
+ // if length of spectrogram is less than 1.0s (100 frames ), then return
4131
+ // basically don't process anything that is less than 1.0s
4100
4132
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
4101
4133
if (seek_end < seek_start + (params.speed_up ? 50 : 100 )) {
4102
4134
return 0 ;
0 commit comments