9
9
#include " common-sdl.h"
10
10
#include " common.h"
11
11
#include " whisper.h"
12
+ #include " grammar-parser.h"
12
13
13
14
#include < sstream>
14
15
#include < cassert>
21
22
#include < vector>
22
23
#include < map>
23
24
25
// Reports whether the file named by `fname` can be opened for reading.
//
// fname   - path to probe
// returns - true when the input stream is in a good state right after the
//           open attempt, false otherwise
bool file_exists(const std::string & fname) {
    // a failed open leaves the stream in a failed state, so good() is false
    std::ifstream probe;
    probe.open(fname.c_str());

    const bool opened_ok = probe.good();
    return opened_ok;
}
29
+
24
30
// command-line parameters
25
31
struct whisper_params {
26
32
int32_t n_threads = std::min(4 , (int32_t ) std::thread::hardware_concurrency());
@@ -30,8 +36,12 @@ struct whisper_params {
30
36
int32_t max_tokens = 32 ;
31
37
int32_t audio_ctx = 0 ;
32
38
33
- float vad_thold = 0 .6f ;
34
- float freq_thold = 100 .0f ;
39
+ float vad_thold = 0 .6f ;
40
+ float freq_thold = 100 .0f ;
41
+
42
+ float grammar_penalty = 100 .0f ;
43
+
44
+ grammar_parser::parse_state grammar_parsed;
35
45
36
46
bool speed_up = false ;
37
47
bool translate = false ;
@@ -45,6 +55,8 @@ struct whisper_params {
45
55
std::string fname_out;
46
56
std::string commands;
47
57
std::string prompt;
58
+ std::string context;
59
+ std::string grammar;
48
60
};
49
61
50
62
void whisper_print_usage (int argc, char ** argv, const whisper_params & params);
@@ -75,6 +87,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
75
87
else if (arg == " -f" || arg == " --file" ) { params.fname_out = argv[++i]; }
76
88
else if (arg == " -cmd" || arg == " --commands" ) { params.commands = argv[++i]; }
77
89
else if (arg == " -p" || arg == " --prompt" ) { params.prompt = argv[++i]; }
90
+ else if (arg == " -ctx" || arg == " --context" ) { params.context = argv[++i]; }
91
+ else if ( arg == " --grammar" ) { params.grammar = argv[++i]; }
92
+ else if ( arg == " --grammar-penalty" ) { params.grammar_penalty = std::stof (argv[++i]); }
78
93
else {
79
94
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
80
95
whisper_print_usage (argc, argv, params);
@@ -109,36 +124,72 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
109
124
fprintf (stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n " , params.fname_out .c_str ());
110
125
fprintf (stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n " , params.commands .c_str ());
111
126
fprintf (stderr, " -p, --prompt [%-7s] the required activation prompt\n " , params.prompt .c_str ());
127
+ fprintf (stderr, " -ctx, --context [%-7s] sample text to help the transcription\n " , params.context .c_str ());
128
+ fprintf (stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n " , params.grammar .c_str ());
129
+ fprintf (stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n " , params.grammar_penalty );
112
130
fprintf (stderr, " \n " );
113
131
}
114
132
115
- std::string transcribe (whisper_context * ctx, const whisper_params & params, const std::vector<float > & pcmf32, float & prob, int64_t & t_ms) {
133
+ std::string transcribe (
134
+ whisper_context * ctx,
135
+ const whisper_params & params,
136
+ const std::vector<float > & pcmf32,
137
+ const std::string & grammar_rule,
138
+ float & logprob_min,
139
+ float & logprob_sum,
140
+ int & n_tokens,
141
+ int64_t & t_ms) {
116
142
const auto t_start = std::chrono::high_resolution_clock::now ();
117
143
118
- prob = 0 .0f ;
144
+ logprob_min = 0 .0f ;
145
+ logprob_sum = 0 .0f ;
146
+ n_tokens = 0 ;
119
147
t_ms = 0 ;
120
148
121
- whisper_full_params wparams = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
149
+ // whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
150
+ whisper_full_params wparams = whisper_full_default_params (WHISPER_SAMPLING_BEAM_SEARCH);
122
151
123
152
wparams.print_progress = false ;
124
153
wparams.print_special = params.print_special ;
125
154
wparams.print_realtime = false ;
126
155
wparams.print_timestamps = !params.no_timestamps ;
127
156
wparams.translate = params.translate ;
128
157
wparams.no_context = true ;
158
+ wparams.no_timestamps = params.no_timestamps ;
129
159
wparams.single_segment = true ;
130
160
wparams.max_tokens = params.max_tokens ;
131
161
wparams.language = params.language .c_str ();
132
162
wparams.n_threads = params.n_threads ;
133
163
134
- wparams.audio_ctx = params.audio_ctx ;
135
- wparams.speed_up = params.speed_up ;
164
+ wparams.audio_ctx = params.audio_ctx ;
165
+ wparams.speed_up = params.speed_up ;
166
+
167
+ wparams.temperature = 0 .4f ;
168
+ wparams.temperature_inc = 1 .0f ;
169
+ wparams.greedy .best_of = 5 ;
170
+
171
+ wparams.beam_search .beam_size = 5 ;
172
+
173
+ wparams.initial_prompt = params.context .data ();
174
+
175
+ const auto & grammar_parsed = params.grammar_parsed ;
176
+ auto grammar_rules = grammar_parsed.c_rules ();
177
+
178
+ if (!params.grammar_parsed .rules .empty () && !grammar_rule.empty ()) {
179
+ if (grammar_parsed.symbol_ids .find (grammar_rule) == grammar_parsed.symbol_ids .end ()) {
180
+ fprintf (stderr, " %s: warning: grammar rule '%s' not found - skipping grammar sampling\n " , __func__, grammar_rule.c_str ());
181
+ } else {
182
+ wparams.grammar_rules = grammar_rules.data ();
183
+ wparams.n_grammar_rules = grammar_rules.size ();
184
+ wparams.i_start_rule = grammar_parsed.symbol_ids .at (grammar_rule);
185
+ wparams.grammar_penalty = params.grammar_penalty ;
186
+ }
187
+ }
136
188
137
189
if (whisper_full (ctx, wparams, pcmf32.data (), pcmf32.size ()) != 0 ) {
138
190
return " " ;
139
191
}
140
192
141
- int prob_n = 0 ;
142
193
std::string result;
143
194
144
195
const int n_segments = whisper_full_n_segments (ctx);
@@ -147,19 +198,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
147
198
148
199
result += text;
149
200
150
- const int n_tokens = whisper_full_n_tokens (ctx, i);
151
- for (int j = 0 ; j < n_tokens ; ++j) {
201
+ const int n = whisper_full_n_tokens (ctx, i);
202
+ for (int j = 0 ; j < n ; ++j) {
152
203
const auto token = whisper_full_get_token_data (ctx, i, j);
153
204
154
- prob += token.p ;
155
- ++prob_n;
205
+ if (token.plog > 0 .0f ) exit (0 );
206
+ logprob_min = std::min (logprob_min, token.plog );
207
+ logprob_sum += token.plog ;
208
+ ++n_tokens;
156
209
}
157
210
}
158
211
159
- if (prob_n > 0 ) {
160
- prob /= prob_n;
161
- }
162
-
163
212
const auto t_end = std::chrono::high_resolution_clock::now ();
164
213
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count ();
165
214
@@ -250,7 +299,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
250
299
fprintf (stderr, " ]\n " );
251
300
}
252
301
253
- std::string k_prompt = " select one from the available words: " ;
302
+ std::string k_prompt = " select one from the available words: " ;
254
303
for (int i = 0 ; i < (int ) allowed_commands.size (); ++i) {
255
304
if (i > 0 ) {
256
305
k_prompt += " , " ;
@@ -418,7 +467,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
418
467
bool is_running = true ;
419
468
bool ask_prompt = true ;
420
469
421
- float prob = 0 .0f ;
470
+ float logprob_min = 0 .0f ;
471
+ float logprob_sum = 0 .0f ;
472
+ int n_tokens = 0 ;
422
473
423
474
std::vector<float > pcmf32_cur;
424
475
@@ -456,7 +507,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
456
507
// detect the commands
457
508
audio.get (params.command_ms , pcmf32_cur);
458
509
459
- const auto txt = ::trim (::transcribe (ctx, params, pcmf32_cur, prob , t_ms));
510
+ const auto txt = ::trim (::transcribe (ctx, params, pcmf32_cur, " " , logprob_min, logprob_sum, n_tokens , t_ms));
460
511
461
512
const auto words = get_words (txt);
462
513
@@ -492,18 +543,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
492
543
493
544
// general-purpose mode
494
545
// freely transcribe the voice into text
495
- int process_general_transcription (struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
546
+ int process_general_transcription (struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
496
547
bool is_running = true ;
497
548
bool have_prompt = false ;
498
549
bool ask_prompt = true ;
499
550
500
- float prob0 = 0 .0f ;
501
- float prob = 0 .0f ;
551
+ float logprob_min0 = 0 .0f ;
552
+ float logprob_min = 0 .0f ;
553
+
554
+ float logprob_sum0 = 0 .0f ;
555
+ float logprob_sum = 0 .0f ;
556
+
557
+ int n_tokens0 = 0 ;
558
+ int n_tokens = 0 ;
502
559
503
560
std::vector<float > pcmf32_cur;
504
561
std::vector<float > pcmf32_prompt;
505
562
506
- const std::string k_prompt = " Ok Whisper, start listening for commands." ;
563
+ std::string k_prompt = " Ok Whisper, start listening for commands." ;
564
+ if (!params.prompt .empty ()) {
565
+ k_prompt = params.prompt ;
566
+ }
507
567
508
568
fprintf (stderr, " \n " );
509
569
fprintf (stderr, " %s: general-purpose mode\n " , __func__);
@@ -536,9 +596,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
536
596
// wait for activation phrase
537
597
audio.get (params.prompt_ms , pcmf32_cur);
538
598
539
- const auto txt = ::trim (::transcribe (ctx, params, pcmf32_cur, prob0 , t_ms));
599
+ const auto txt = ::trim (::transcribe (ctx, params, pcmf32_cur, " prompt " , logprob_min0, logprob_sum0, n_tokens0 , t_ms));
540
600
541
- fprintf (stdout, " %s: Heard '%s%s%s', (t = %d ms)\n " , __func__, " \033 [1m" , txt.c_str (), " \033 [0m" , (int ) t_ms);
601
+ const float p = 100 .0f * std::exp (logprob_min0);
602
+
603
+ fprintf (stdout, " %s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n " , __func__, " \033 [1m" , txt.c_str (), " \033 [0m" , (int ) t_ms, p);
542
604
543
605
const float sim = similarity (txt, k_prompt);
544
606
@@ -559,19 +621,30 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
559
621
// we have heard the activation phrase, now detect the commands
560
622
audio.get (params.command_ms , pcmf32_cur);
561
623
624
+ // printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
625
+ // printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
626
+
627
+ // prepend 3 seconds of silence
628
+ pcmf32_cur.insert (pcmf32_cur.begin (), 3 .0f *WHISPER_SAMPLE_RATE, 0 .0f );
629
+
562
630
// prepend the prompt audio
563
631
pcmf32_cur.insert (pcmf32_cur.begin (), pcmf32_prompt.begin (), pcmf32_prompt.end ());
564
632
565
- const auto txt = ::trim (::transcribe (ctx, params, pcmf32_cur, prob , t_ms));
633
+ const auto txt = ::trim (::transcribe (ctx, params, pcmf32_cur, " root " , logprob_min, logprob_sum, n_tokens , t_ms));
566
634
567
- prob = 100 .0f *(prob - prob0);
635
+ // const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
636
+ const float p = 100 .0f * std::exp (logprob_min);
568
637
569
638
// fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
570
639
571
640
// find the prompt in the text
572
641
float best_sim = 0 .0f ;
573
642
size_t best_len = 0 ;
574
- for (int n = 0.8 *k_prompt.size (); n <= 1.2 *k_prompt.size (); ++n) {
643
+ for (size_t n = 0.8 *k_prompt.size (); n <= 1.2 *k_prompt.size (); ++n) {
644
+ if (n >= txt.size ()) {
645
+ break ;
646
+ }
647
+
575
648
const auto prompt = txt.substr (0 , n);
576
649
577
650
const float sim = similarity (prompt, k_prompt);
@@ -584,9 +657,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
584
657
}
585
658
}
586
659
587
- const std::string command = ::trim (txt.substr (best_len));
660
+ fprintf (stdout, " %s: DEBUG: txt = '%s', prob = %.2f%%\n " , __func__, txt.c_str (), p);
661
+ if (best_len == 0 ) {
662
+ fprintf (stdout, " %s: WARNING: command not recognized, try again\n " , __func__);
663
+ } else {
664
+ // cut the prompt from the decoded text
665
+ const std::string command = ::trim (txt.substr (best_len));
666
+
667
+ fprintf (stdout, " %s: Command '%s%s%s', (t = %d ms)\n " , __func__, " \033 [1m" , command.c_str (), " \033 [0m" , (int ) t_ms);
668
+ }
588
669
589
- fprintf (stdout, " %s: Command '%s%s%s', (t = %d ms)\n " , __func__, " \033 [1m" , command.c_str (), " \033 [0m" , (int ) t_ms);
590
670
fprintf (stdout, " \n " );
591
671
}
592
672
@@ -654,12 +734,36 @@ int main(int argc, char ** argv) {
654
734
655
735
int ret_val = 0 ;
656
736
657
- if (!params.commands .empty ()) {
658
- ret_val = process_command_list (ctx, audio, params);
659
- } else if (!params.prompt .empty ()) {
660
- ret_val = always_prompt_transcription (ctx, audio, params);
661
- } else {
662
- ret_val = process_general_transcription (ctx, audio, params);
737
+ if (!params.grammar .empty ()) {
738
+ auto & grammar = params.grammar_parsed ;
739
+ if (file_exists (params.grammar .c_str ())) {
740
+ // read grammar from file
741
+ std::ifstream ifs (params.grammar .c_str ());
742
+ const std::string txt = std::string ((std::istreambuf_iterator<char >(ifs)), std::istreambuf_iterator<char >());
743
+ grammar = grammar_parser::parse (txt.c_str ());
744
+ } else {
745
+ // read grammar from string
746
+ grammar = grammar_parser::parse (params.grammar .c_str ());
747
+ }
748
+
749
+ // will be empty (default) if there are parse errors
750
+ if (grammar.rules .empty ()) {
751
+ ret_val = 1 ;
752
+ } else {
753
+ fprintf (stderr, " %s: grammar:\n " , __func__);
754
+ grammar_parser::print_grammar (stderr, grammar);
755
+ fprintf (stderr, " \n " );
756
+ }
757
+ }
758
+
759
+ if (ret_val == 0 ) {
760
+ if (!params.commands .empty ()) {
761
+ ret_val = process_command_list (ctx, audio, params);
762
+ } else if (!params.prompt .empty () && params.grammar_parsed .rules .empty ()) {
763
+ ret_val = always_prompt_transcription (ctx, audio, params);
764
+ } else {
765
+ ret_val = process_general_transcription (ctx, audio, params);
766
+ }
663
767
}
664
768
665
769
audio.pause ();
0 commit comments