Skip to content

Commit 3e5c7fe

Browse files
ejonesggerganov
andauthoredNov 13, 2023
whisper : add grammar-based sampling (ggerganov#1229)
* whisper : add grammar-based sampling * build : fix after master merge * command : fix exception when recognizing the command * whisper : fine-tuning grammar functionality * command : grammar-related improvements - option to read grammar from file - add sample grammars for colors and chess moves - fine-tune the performance further * grammars : add assistant + update comments * command : enable beam-search, add "no_timestamps", add "context", add p * whisper : remove comment --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent c23598e commit 3e5c7fe

10 files changed

+1289
-69
lines changed
 

‎Makefile

+2-2
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,8 @@ quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
362362
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
363363
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
364364

365-
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
366-
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
365+
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
366+
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
367367

368368
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
369369
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)

‎examples/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ add_library(${TARGET} STATIC
2323
common.cpp
2424
common-ggml.h
2525
common-ggml.cpp
26+
grammar-parser.cpp
2627
)
2728

2829
include(DefaultTargetOptions)

‎examples/command/command.cpp

+140-36
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "common-sdl.h"
1010
#include "common.h"
1111
#include "whisper.h"
12+
#include "grammar-parser.h"
1213

1314
#include <sstream>
1415
#include <cassert>
@@ -21,6 +22,11 @@
2122
#include <vector>
2223
#include <map>
2324

25+
bool file_exists(const std::string & fname) {
26+
std::ifstream f(fname.c_str());
27+
return f.good();
28+
}
29+
2430
// command-line parameters
2531
struct whisper_params {
2632
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@@ -30,8 +36,12 @@ struct whisper_params {
3036
int32_t max_tokens = 32;
3137
int32_t audio_ctx = 0;
3238

33-
float vad_thold = 0.6f;
34-
float freq_thold = 100.0f;
39+
float vad_thold = 0.6f;
40+
float freq_thold = 100.0f;
41+
42+
float grammar_penalty = 100.0f;
43+
44+
grammar_parser::parse_state grammar_parsed;
3545

3646
bool speed_up = false;
3747
bool translate = false;
@@ -45,6 +55,8 @@ struct whisper_params {
4555
std::string fname_out;
4656
std::string commands;
4757
std::string prompt;
58+
std::string context;
59+
std::string grammar;
4860
};
4961

5062
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -75,6 +87,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
7587
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
7688
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
7789
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
90+
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
91+
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
92+
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
7893
else {
7994
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
8095
whisper_print_usage(argc, argv, params);
@@ -109,36 +124,72 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
109124
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
110125
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
111126
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
127+
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
128+
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
129+
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
112130
fprintf(stderr, "\n");
113131
}
114132

115-
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
133+
std::string transcribe(
134+
whisper_context * ctx,
135+
const whisper_params & params,
136+
const std::vector<float> & pcmf32,
137+
const std::string & grammar_rule,
138+
float & logprob_min,
139+
float & logprob_sum,
140+
int & n_tokens,
141+
int64_t & t_ms) {
116142
const auto t_start = std::chrono::high_resolution_clock::now();
117143

118-
prob = 0.0f;
144+
logprob_min = 0.0f;
145+
logprob_sum = 0.0f;
146+
n_tokens = 0;
119147
t_ms = 0;
120148

121-
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
149+
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
150+
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
122151

123152
wparams.print_progress = false;
124153
wparams.print_special = params.print_special;
125154
wparams.print_realtime = false;
126155
wparams.print_timestamps = !params.no_timestamps;
127156
wparams.translate = params.translate;
128157
wparams.no_context = true;
158+
wparams.no_timestamps = params.no_timestamps;
129159
wparams.single_segment = true;
130160
wparams.max_tokens = params.max_tokens;
131161
wparams.language = params.language.c_str();
132162
wparams.n_threads = params.n_threads;
133163

134-
wparams.audio_ctx = params.audio_ctx;
135-
wparams.speed_up = params.speed_up;
164+
wparams.audio_ctx = params.audio_ctx;
165+
wparams.speed_up = params.speed_up;
166+
167+
wparams.temperature = 0.4f;
168+
wparams.temperature_inc = 1.0f;
169+
wparams.greedy.best_of = 5;
170+
171+
wparams.beam_search.beam_size = 5;
172+
173+
wparams.initial_prompt = params.context.data();
174+
175+
const auto & grammar_parsed = params.grammar_parsed;
176+
auto grammar_rules = grammar_parsed.c_rules();
177+
178+
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
179+
if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
180+
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
181+
} else {
182+
wparams.grammar_rules = grammar_rules.data();
183+
wparams.n_grammar_rules = grammar_rules.size();
184+
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
185+
wparams.grammar_penalty = params.grammar_penalty;
186+
}
187+
}
136188

137189
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
138190
return "";
139191
}
140192

141-
int prob_n = 0;
142193
std::string result;
143194

144195
const int n_segments = whisper_full_n_segments(ctx);
@@ -147,19 +198,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
147198

148199
result += text;
149200

150-
const int n_tokens = whisper_full_n_tokens(ctx, i);
151-
for (int j = 0; j < n_tokens; ++j) {
201+
const int n = whisper_full_n_tokens(ctx, i);
202+
for (int j = 0; j < n; ++j) {
152203
const auto token = whisper_full_get_token_data(ctx, i, j);
153204

154-
prob += token.p;
155-
++prob_n;
205+
if(token.plog > 0.0f) exit(0);
206+
logprob_min = std::min(logprob_min, token.plog);
207+
logprob_sum += token.plog;
208+
++n_tokens;
156209
}
157210
}
158211

159-
if (prob_n > 0) {
160-
prob /= prob_n;
161-
}
162-
163212
const auto t_end = std::chrono::high_resolution_clock::now();
164213
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
165214

@@ -250,7 +299,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
250299
fprintf(stderr, " ]\n");
251300
}
252301

253-
std::string k_prompt = "select one from the available words: ";
302+
std::string k_prompt = "select one from the available words: ";
254303
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
255304
if (i > 0) {
256305
k_prompt += ", ";
@@ -418,7 +467,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
418467
bool is_running = true;
419468
bool ask_prompt = true;
420469

421-
float prob = 0.0f;
470+
float logprob_min = 0.0f;
471+
float logprob_sum = 0.0f;
472+
int n_tokens = 0;
422473

423474
std::vector<float> pcmf32_cur;
424475

@@ -456,7 +507,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
456507
// detect the commands
457508
audio.get(params.command_ms, pcmf32_cur);
458509

459-
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
510+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
460511

461512
const auto words = get_words(txt);
462513

@@ -492,18 +543,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
492543

493544
// general-purpose mode
494545
// freely transcribe the voice into text
495-
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
546+
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
496547
bool is_running = true;
497548
bool have_prompt = false;
498549
bool ask_prompt = true;
499550

500-
float prob0 = 0.0f;
501-
float prob = 0.0f;
551+
float logprob_min0 = 0.0f;
552+
float logprob_min = 0.0f;
553+
554+
float logprob_sum0 = 0.0f;
555+
float logprob_sum = 0.0f;
556+
557+
int n_tokens0 = 0;
558+
int n_tokens = 0;
502559

503560
std::vector<float> pcmf32_cur;
504561
std::vector<float> pcmf32_prompt;
505562

506-
const std::string k_prompt = "Ok Whisper, start listening for commands.";
563+
std::string k_prompt = "Ok Whisper, start listening for commands.";
564+
if (!params.prompt.empty()) {
565+
k_prompt = params.prompt;
566+
}
507567

508568
fprintf(stderr, "\n");
509569
fprintf(stderr, "%s: general-purpose mode\n", __func__);
@@ -536,9 +596,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
536596
// wait for activation phrase
537597
audio.get(params.prompt_ms, pcmf32_cur);
538598

539-
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
599+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
540600

541-
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
601+
const float p = 100.0f * std::exp(logprob_min0);
602+
603+
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
542604

543605
const float sim = similarity(txt, k_prompt);
544606

@@ -559,19 +621,30 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
559621
// we have heard the activation phrase, now detect the commands
560622
audio.get(params.command_ms, pcmf32_cur);
561623

624+
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
625+
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
626+
627+
// prepend 3 second of silence
628+
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
629+
562630
// prepend the prompt audio
563631
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
564632

565-
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
633+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
566634

567-
prob = 100.0f*(prob - prob0);
635+
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
636+
const float p = 100.0f * std::exp(logprob_min);
568637

569638
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
570639

571640
// find the prompt in the text
572641
float best_sim = 0.0f;
573642
size_t best_len = 0;
574-
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
643+
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
644+
if (n >= txt.size()) {
645+
break;
646+
}
647+
575648
const auto prompt = txt.substr(0, n);
576649

577650
const float sim = similarity(prompt, k_prompt);
@@ -584,9 +657,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
584657
}
585658
}
586659

587-
const std::string command = ::trim(txt.substr(best_len));
660+
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
661+
if (best_len == 0) {
662+
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
663+
} else {
664+
// cut the prompt from the decoded text
665+
const std::string command = ::trim(txt.substr(best_len));
666+
667+
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
668+
}
588669

589-
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
590670
fprintf(stdout, "\n");
591671
}
592672

@@ -654,12 +734,36 @@ int main(int argc, char ** argv) {
654734

655735
int ret_val = 0;
656736

657-
if (!params.commands.empty()) {
658-
ret_val = process_command_list(ctx, audio, params);
659-
} else if (!params.prompt.empty()) {
660-
ret_val = always_prompt_transcription(ctx, audio, params);
661-
} else {
662-
ret_val = process_general_transcription(ctx, audio, params);
737+
if (!params.grammar.empty()) {
738+
auto & grammar = params.grammar_parsed;
739+
if (file_exists(params.grammar.c_str())) {
740+
// read grammar from file
741+
std::ifstream ifs(params.grammar.c_str());
742+
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
743+
grammar = grammar_parser::parse(txt.c_str());
744+
} else {
745+
// read grammar from string
746+
grammar = grammar_parser::parse(params.grammar.c_str());
747+
}
748+
749+
// will be empty (default) if there are parse errors
750+
if (grammar.rules.empty()) {
751+
ret_val = 1;
752+
} else {
753+
fprintf(stderr, "%s: grammar:\n", __func__);
754+
grammar_parser::print_grammar(stderr, grammar);
755+
fprintf(stderr, "\n");
756+
}
757+
}
758+
759+
if (ret_val == 0) {
760+
if (!params.commands.empty()) {
761+
ret_val = process_command_list(ctx, audio, params);
762+
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
763+
ret_val = always_prompt_transcription(ctx, audio, params);
764+
} else {
765+
ret_val = process_general_transcription(ctx, audio, params);
766+
}
663767
}
664768

665769
audio.pause();

‎examples/grammar-parser.cpp

+423
Large diffs are not rendered by default.

‎examples/grammar-parser.h

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Implements a parser for an extended Backus-Naur form (BNF), producing the
2+
// binary context-free grammar format specified by whisper.h. Supports character
3+
// ranges, grouping, and repetition operators. As an example, a grammar for
4+
// arithmetic might look like:
5+
//
6+
// root ::= expr
7+
// expr ::= term ([-+*/] term)*
8+
// term ::= num | "(" space expr ")" space
9+
// num ::= [0-9]+ space
10+
// space ::= [ \t\n]*
11+
12+
#pragma once
13+
#include "whisper.h"
14+
#include <vector>
15+
#include <map>
16+
#include <cstdint>
17+
#include <string>
18+
19+
namespace grammar_parser {
20+
struct parse_state {
21+
std::map<std::string, uint32_t> symbol_ids;
22+
std::vector<std::vector<whisper_grammar_element>> rules;
23+
24+
std::vector<const whisper_grammar_element *> c_rules() const;
25+
};
26+
27+
parse_state parse(const char * src);
28+
void print_grammar(FILE * file, const parse_state & state);
29+
}

‎grammars/assistant.gbnf

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# - "turn on lights."
2+
# - "set thermostat to 22."
3+
# - "increase TV by 10."
4+
# - "decrease oven by 50."
5+
# - "play music."
6+
# - "stop podcast."
7+
# - "schedule cleaning at 3pm."
8+
# - "cancel cleaning."
9+
# - "remind me to buy milk at 5pm."
10+
# - "show me security system."
11+
# - "hide washing machine."
12+
# - "what is the lights status?"
13+
# - "what is the current thermostat value?"
14+
# - "what is the security system status?"
15+
# - "what is the door lock status?"
16+
# - "what is the camera battery level?"
17+
# - "what is the weather like today?"
18+
# - "what is the forecast for tomorrow?"
19+
# - "what is the time?"
20+
# - "what is my schedule for today?"
21+
# - "what tasks do I have?"
22+
# - "what reminders do I have?"
23+
#
24+
# example:
25+
#
26+
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
27+
#
28+
29+
root ::= init " " (command | question) "."
30+
prompt ::= init
31+
32+
# leading space is very important!
33+
init ::= " Ok Whisper, start listening for commands."
34+
35+
command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
36+
"Increase " device " by " value | "Decrease " device " by " value |
37+
"Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
38+
"Remind me to " task " at " time | "Show me " device | "Hide " device
39+
40+
question ::= "What is the " device " status?" | "What is the current " device " value?" |
41+
"What is the " device " temperature?" | "What is the " device " humidity?" |
42+
"What is the " device " power consumption?" | "What is the " device " battery level?" |
43+
"What is the weather like today?" | "What is the forecast for tomorrow?" |
44+
"What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
45+
"What reminders do I have?"
46+
47+
device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
48+
"music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
49+
"vacuum cleaner"
50+
51+
value ::= [0-9]+
52+
53+
media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
54+
55+
task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
56+
57+
time ::= [0-9] [0-9]? ("am" | "pm")?

‎grammars/chess.gbnf

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# - bishop to c3
2+
# - rook to d4
3+
# - knight to e5
4+
# - d4 d5 knight to c3
5+
# - c3 queen to d4 king b1
6+
# - pawn to a1 bishop to b2 knight to c3
7+
#
8+
# The prompt (--prompt) is the initial phrase that the user has to say.
9+
# This is used to prime Whisper with how the user is expected to speak.
10+
#
11+
# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
12+
# Longer context is better, but it slightly increases the processing time.
13+
#
14+
# example:
15+
#
16+
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
17+
#
18+
19+
root ::= init move move? move? "."
20+
prompt ::= init "."
21+
22+
# leading space is very important!
23+
init ::= " rook to b4, f3"
24+
25+
move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
26+
27+
piece ::= "bishop" | "rook" | "knight" | "queen"
28+
king ::= "king"
29+
pawn ::= "pawn"

‎grammars/colors.gbnf

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# - red
2+
# - green
3+
# - blue
4+
#
5+
# example:
6+
#
7+
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
8+
#
9+
10+
root ::= init color "."
11+
prompt ::= init "."
12+
13+
# leading space is very important!
14+
init ::= " red, green, blue"
15+
16+
color ::= ", " ("red" | "green" | "blue")

‎whisper.cpp

+555-31
Large diffs are not rendered by default.

‎whisper.h

+37
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,37 @@ extern "C" {
109109
void (*close)(void * ctx);
110110
} whisper_model_loader;
111111

112+
// grammar element type
113+
enum whisper_gretype {
114+
// end of rule definition
115+
WHISPER_GRETYPE_END = 0,
116+
117+
// start of alternate definition for rule
118+
WHISPER_GRETYPE_ALT = 1,
119+
120+
// non-terminal element: reference to rule
121+
WHISPER_GRETYPE_RULE_REF = 2,
122+
123+
// terminal element: character (code point)
124+
WHISPER_GRETYPE_CHAR = 3,
125+
126+
// inverse char(s) ([^a], [^a-b] [^abc])
127+
WHISPER_GRETYPE_CHAR_NOT = 4,
128+
129+
// modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
130+
// be an inclusive range ([a-z])
131+
WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
132+
133+
// modifies a preceding WHISPER_GRETYPE_CHAR or
134+
// WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
135+
WHISPER_GRETYPE_CHAR_ALT = 6,
136+
};
137+
138+
typedef struct whisper_grammar_element {
139+
enum whisper_gretype type;
140+
uint32_t value; // Unicode code point or rule ID
141+
} whisper_grammar_element;
142+
112143
// Various functions for loading a ggml whisper model.
113144
// Allocate (almost) all memory needed for the model.
114145
// Return NULL on failure
@@ -402,6 +433,7 @@ extern "C" {
402433

403434
bool translate;
404435
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
436+
bool no_timestamps; // do not generate timestamps
405437
bool single_segment; // force single segment output (useful for streaming)
406438
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
407439
bool print_progress; // print progress information
@@ -479,6 +511,11 @@ extern "C" {
479511
// called by each decoder to filter obtained logits
480512
whisper_logits_filter_callback logits_filter_callback;
481513
void * logits_filter_callback_user_data;
514+
515+
const whisper_grammar_element ** grammar_rules;
516+
size_t n_grammar_rules;
517+
size_t i_start_rule;
518+
float grammar_penalty;
482519
};
483520

484521
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()

0 commit comments

Comments
 (0)
Please sign in to comment.