Skip to content

Commit

Permalink
added rudimentary support for outetts v0.3 500m and 1b models
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Jan 18, 2025
1 parent 6390a99 commit b548695
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions examples/tts/tts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
}

// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
static std::string process_text(const std::string & text) {
static std::string process_text(const std::string & text, bool is_version_0_3) {

// For now I skipped text romanization as I am unsure how to handle
// uroman and MeCab implementations in C++
Expand Down Expand Up @@ -401,7 +401,7 @@ static std::string process_text(const std::string & text) {
if (c == ' ') {
prompt_clean += "<|text_sep|>";
*/
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), is_version_0_3?"<|space|>":"<|text_sep|>");

return processed_text;
}
Expand All @@ -425,8 +425,7 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}

static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
const std::string& delimiter = "<|text_sep|>";
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const std::string& delimiter) {

std::vector<llama_token> result;
size_t start = 0;
Expand Down Expand Up @@ -523,6 +522,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> codes;
std::vector<llama_token> guide_tokens;

//determine OuteTTS version and vocab code offset. v0.2 does not have <|space|>, but v0.3 does
const bool is_version_0_3 = common_tokenize(vocab,"<|space|>",false,true).size()==1;
//determine the offset of the first audio code token
const int cts_offset = common_tokenize(vocab,"<|0|>",false,true)[0];

// process prompt and generate voice codes
{
LOG_INF("%s: constructing prompt ..\n", __func__);
Expand All @@ -531,13 +535,17 @@ int main(int argc, char ** argv) {

prompt_init(prompt_inp, vocab);

prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
if (is_version_0_3) {
prompt_add(prompt_inp, vocab, "<|text_start|>the<|space|>overall<|space|>package<|space|>from<|space|>just<|space|>two<|space|>people<|space|>is<|space|>pretty<|space|>remarkable<|space|>sure<|space|>i<|space|>have<|space|>some<|space|>critiques<|space|>about<|space|>some<|space|>of<|space|>the<|space|>gameplay<|space|>aspects<|space|>but<|space|>its<|space|>still<|space|>really<|space|>enjoyable<|space|>and<|space|>it<|space|>looks<|space|>lovely<|space|>", false, true);
} else {
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
}

// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);
std::string prompt_clean = process_text(params.prompt, is_version_0_3);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, is_version_0_3?"<|space|>":"<|text_sep|>");
}

LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
Expand All @@ -549,8 +557,8 @@ int main(int argc, char ** argv) {

// disabled to save time on tokenizing each time
// TODO: load voices from the json files
#if 0
const std::string voice_data = R"(<|audio_start|>
#if 1
std::string voice_data = R"(<|audio_start|>
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
Expand Down Expand Up @@ -582,12 +590,19 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";

auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]);
if (is_version_0_3)
{
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_start\|>)"), "");
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
}
printf("\n\n");

prompt_add(prompt_inp, vocab, voice_data, false, true);

// printf("\n\n");
// for (int i = 0; i < tmp.size(); ++i) {
// printf("%d, ", tmp[i]);
// }
// printf("\n\n");
#else
prompt_add(prompt_inp, llama_tokens {
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
Expand Down Expand Up @@ -882,7 +897,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}

// remove all non-audio tokens (i.e. < 151672 || > 155772)
codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
codes.erase(std::remove_if(codes.begin(), codes.end(), [cts_offset](llama_token t) { return t < cts_offset || t > (cts_offset+4100); }), codes.end());

{
const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
Expand All @@ -891,7 +906,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}

for (auto & token : codes) {
token -= 151672;
token -= cts_offset;
}

const auto t_voc_start = ggml_time_us();
Expand Down

0 comments on commit b548695

Please sign in to comment.