From 204ba96b76098c281e0e73385cb2e61aab89c800 Mon Sep 17 00:00:00 2001 From: Lyrcaxis Date: Wed, 26 Feb 2025 11:55:01 +0200 Subject: [PATCH 1/5] Made Vocabulary properties be initialized only ONCE on creation --- LLama/Native/SafeLlamaModelHandle.cs | 238 +++++++-------------------- 1 file changed, 56 insertions(+), 182 deletions(-) diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 9439c2bb..619a3afc 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Linq; using System.Text; using LLama.Exceptions; @@ -631,34 +632,51 @@ public sealed class Vocabulary internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model); + /// + /// Cache of all the tokens in the vocabulary, and their string representation + /// + public readonly IReadOnlyDictionary TokenToString; + internal Vocabulary(SafeLlamaModelHandle model) { _model = model; - } - - private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken) - { - if (!token.HasValue) - return null; + TokenToString = GetVocabCache(); - // Try to convert using a fixed size buffer - const int buffSize = 32; - Span buff = stackalloc byte[buffSize]; - var tokenLength = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken); - - // Negative indicates that there was no result - if (tokenLength <= 0) - return null; - - // if the original buffer wasn't large enough, try again with one that's the right size - if (tokenLength > buffSize) + unsafe { - buff = stackalloc byte[(int)tokenLength]; - _ = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken); + var vocabNative = llama_model_get_vocab(_model); + Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative); + Type = LLamaVocabNative.llama_vocab_type(vocabNative); + BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative)); + EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative)); + Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative)); + Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative)); + SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative)); + InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative)); + InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative)); + InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative)); + InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative)); + InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative)); + InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative)); + EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative)); + DecoderStartToken = Normalize(llama_model_decoder_start_token(_model)); + ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative); + ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative); } + } - var slice = buff.Slice(0, (int)tokenLength); - return Encoding.UTF8.GetStringFromSpan(slice); + private Dictionary GetVocabCache() + { + var decoder = Encoding.UTF8.GetDecoder(); + var (bytesArr, charsArr) = (new byte[1024], new char[1024]); + return Enumerable.Range(0, Count).ToDictionary( + keySelector: i => (LLamaToken) i, + elementSelector: i => + { + decoder.Convert(bytesArr, 0, (int) _model.TokenToSpan(i, bytesArr), charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _); + return string.Join("", charsArr.Take(charsUsed)); + } + ); } private static LLamaToken? Normalize(LLamaToken token) @@ -669,232 +687,88 @@ internal Vocabulary(SafeLlamaModelHandle model) /// /// Total number of tokens in this vocabulary /// - public int Count - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_n_tokens(VocabNative); - } - } - } + public int Count { get; init; } /// /// Get the the type of this vocabulary /// - public LLamaVocabType Type - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_type(VocabNative); - } - } - } + public LLamaVocabType Type { get; init; } /// /// Get the Beginning of Sentence token for this model /// - public LLamaToken? BOS - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_bos(VocabNative)); - } - } - } + public LLamaToken? BOS { get; init; } /// /// Get the End of Sentence token for this model /// - public LLamaToken? EOS - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_eos(VocabNative)); - } - } - } + public LLamaToken? EOS { get; init; } /// /// Get the newline token for this model /// - public LLamaToken? Newline - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_nl(VocabNative)); - } - } - } + public LLamaToken? Newline { get; init; } /// /// Get the padding token for this model /// - public LLamaToken? Pad - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_pad(VocabNative)); - } - } - } + public LLamaToken? Pad { get; init; } /// /// Get the sentence separator token for this model /// - public LLamaToken? SEP - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_sep(VocabNative)); - } - } - } + public LLamaToken? SEP { get; init; } /// /// Codellama beginning of infill prefix /// - public LLamaToken? InfillPrefix - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_pre(VocabNative)); - } - } - } + public LLamaToken? InfillPrefix { get; init; } /// /// Codellama beginning of infill middle /// - public LLamaToken? InfillMiddle - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_mid(VocabNative)); - } - } - } + public LLamaToken? InfillMiddle { get; init; } /// /// Codellama beginning of infill suffix /// - public LLamaToken? InfillSuffix - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_suf(VocabNative)); - } - } - } + public LLamaToken? InfillSuffix { get; init; } /// /// Codellama pad /// - public LLamaToken? InfillPad - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_pad(VocabNative)); - } - } - } + public LLamaToken? InfillPad { get; init; } /// /// Codellama rep /// - public LLamaToken? InfillRep - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_rep(VocabNative)); - } - } - } + public LLamaToken? InfillRep { get; init; } /// /// Codellama rep /// - public LLamaToken? InfillSep - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_sep(VocabNative)); - } - } - } + public LLamaToken? InfillSep { get; init; } /// /// end-of-turn token /// - public LLamaToken? EOT - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_eot(VocabNative)); - } - } - } + public LLamaToken? EOT { get; init; } /// /// For encoder-decoder models, this function returns id of the token that must be provided /// to the decoder to start generating output sequence. /// - public LLamaToken? DecoderStartToken => Normalize(llama_model_decoder_start_token(_model)); + public LLamaToken? DecoderStartToken { get; init; } /// /// Check if the current model requires a BOS token added /// - public bool ShouldAddBOS - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_get_add_bos(llama_model_get_vocab(_model)); - } - } - } + public bool ShouldAddBOS { get; init; } /// /// Check if the current model requires a EOS token added /// - public bool ShouldAddEOS - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_get_add_eos(llama_model_get_vocab(_model)); - } - } - } + public bool ShouldAddEOS { get; init; } } } } From 1df95685e9e41fd0b04f8de7419a2b979248de86 Mon Sep 17 00:00:00 2001 From: Lyrcaxis Date: Wed, 26 Feb 2025 12:25:20 +0200 Subject: [PATCH 2/5] Added cache for EOG and Control tokens to Vocabulary --- LLama/Native/LLamaToken.cs | 10 ++-------- LLama/Native/SafeLlamaModelHandle.cs | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index dd45ebbe..5aeac5ca 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -98,10 +98,7 @@ public bool IsControl(SafeLlamaModelHandle model) /// public bool IsControl(SafeLlamaModelHandle.Vocabulary vocab) { - unsafe - { - return LLamaVocabNative.llama_vocab_is_control(vocab.VocabNative, this); - } + return vocab.ControlTokens.Contains(this); } /// @@ -121,10 +118,7 @@ public bool IsEndOfGeneration(SafeLlamaModelHandle model) /// public bool IsEndOfGeneration(SafeLlamaModelHandle.Vocabulary vocab) { - unsafe - { - return LLamaVocabNative.llama_vocab_is_eog(vocab.VocabNative, this); - } + return vocab.EOGTokens.Contains(this); } /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 619a3afc..cdcc35b5 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -633,15 +633,26 @@ public sealed class Vocabulary internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model); /// - /// Cache of all the tokens in the vocabulary, and their string representation + /// Map of each token in this vocabulary to its string representation /// public readonly IReadOnlyDictionary TokenToString; + /// + /// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc) + /// + public readonly HashSet EOGTokens; + + /// + /// Contains unique tokens that exist for inference control rather than text output + /// + public readonly HashSet ControlTokens; + internal Vocabulary(SafeLlamaModelHandle model) { _model = model; TokenToString = GetVocabCache(); + // Cache the various properties that llama.cpp API exposes about the vocab unsafe { var vocabNative = llama_model_get_vocab(_model); @@ -662,6 +673,9 @@ internal Vocabulary(SafeLlamaModelHandle model) DecoderStartToken = Normalize(llama_model_decoder_start_token(_model)); ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative); ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative); + + EOGTokens = new HashSet(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token))); + ControlTokens = new HashSet(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token))); } } From 8b2b7cca62d525a20d3e4f55fb6ff75aa713b696 Mon Sep 17 00:00:00 2001 From: Lyrcaxis Date: Wed, 26 Feb 2025 17:03:07 +0200 Subject: [PATCH 3/5] Addressed change requests --- LLama/Native/LLamaToken.cs | 1 + LLama/Native/SafeLlamaModelHandle.cs | 44 ++++++++++++++-------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index 5aeac5ca..2b738ce8 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -1,4 +1,5 @@ using System.Diagnostics; +using System.Linq; namespace LLama.Native; diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index cdcc35b5..d68274da 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -635,17 +635,17 @@ public sealed class Vocabulary /// /// Map of each token in this vocabulary to its string representation /// - public readonly IReadOnlyDictionary TokenToString; + internal readonly IReadOnlyDictionary TokenToString; /// /// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc) /// - public readonly HashSet EOGTokens; + internal readonly IReadOnlyList EOGTokens; /// /// Contains unique tokens that exist for inference control rather than text output /// - public readonly HashSet ControlTokens; + internal readonly IReadOnlyList ControlTokens; internal Vocabulary(SafeLlamaModelHandle model) { @@ -674,8 +674,8 @@ internal Vocabulary(SafeLlamaModelHandle model) ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative); ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative); - EOGTokens = new HashSet(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token))); - ControlTokens = new HashSet(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token))); + EOGTokens = TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token)).ToList(); + ControlTokens = TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token)).ToList(); } } @@ -701,88 +701,88 @@ private Dictionary GetVocabCache() /// /// Total number of tokens in this vocabulary /// - public int Count { get; init; } + public int Count { get; } /// /// Get the the type of this vocabulary /// - public LLamaVocabType Type { get; init; } + public LLamaVocabType Type { get; } /// /// Get the Beginning of Sentence token for this model /// - public LLamaToken? BOS { get; init; } + public LLamaToken? BOS { get; } /// /// Get the End of Sentence token for this model /// - public LLamaToken? EOS { get; init; } + public LLamaToken? EOS { get; } /// /// Get the newline token for this model /// - public LLamaToken? Newline { get; init; } + public LLamaToken? Newline { get; } /// /// Get the padding token for this model /// - public LLamaToken? Pad { get; init; } + public LLamaToken? Pad { get; } /// /// Get the sentence separator token for this model /// - public LLamaToken? SEP { get; init; } + public LLamaToken? SEP { get; } /// /// Codellama beginning of infill prefix /// - public LLamaToken? InfillPrefix { get; init; } + public LLamaToken? InfillPrefix { get; } /// /// Codellama beginning of infill middle /// - public LLamaToken? InfillMiddle { get; init; } + public LLamaToken? InfillMiddle { get; } /// /// Codellama beginning of infill suffix /// - public LLamaToken? InfillSuffix { get; init; } + public LLamaToken? InfillSuffix { get; } /// /// Codellama pad /// - public LLamaToken? InfillPad { get; init; } + public LLamaToken? InfillPad { get; } /// /// Codellama rep /// - public LLamaToken? InfillRep { get; init; } + public LLamaToken? InfillRep { get; } /// /// Codellama rep /// - public LLamaToken? InfillSep { get; init; } + public LLamaToken? InfillSep { get; } /// /// end-of-turn token /// - public LLamaToken? EOT { get; init; } + public LLamaToken? EOT { get; } /// /// For encoder-decoder models, this function returns id of the token that must be provided /// to the decoder to start generating output sequence. /// - public LLamaToken? DecoderStartToken { get; init; } + public LLamaToken? DecoderStartToken { get; } /// /// Check if the current model requires a BOS token added /// - public bool ShouldAddBOS { get; init; } + public bool ShouldAddBOS { get; } /// /// Check if the current model requires a EOS token added /// - public bool ShouldAddEOS { get; init; } + public bool ShouldAddEOS { get; } } } } From 8f2168581509cdcee136ebe506c49df1de504fd1 Mon Sep 17 00:00:00 2001 From: Lyrcaxis Date: Sun, 16 Mar 2025 18:20:00 +0200 Subject: [PATCH 4/5] Tweaks on Vocabulary --- LLama/Native/NativeApi.cs | 22 +++++++-- LLama/Native/SafeLlamaModelHandle.cs | 68 ++++++++++++++-------------- 2 files changed, 52 insertions(+), 38 deletions(-) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 06e3baee..b862fc06 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -7,7 +7,7 @@ namespace LLama.Native /// /// Direct translation of the llama.cpp API /// - public static partial class NativeApi + public static partial class NativeApi { /// /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. @@ -202,15 +202,31 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* /// The length written, or if the buffer is too small a negative that indicates the length required public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span buffer, int lstrip, bool special) { + unsafe + { + return llama_token_to_piece(vocab.VocabNative, llamaToken, buffer, lstrip, special); + } + } + + /// + /// Convert a single token into text + /// + /// + /// + /// buffer to write string into + /// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') + /// If true, special tokens are rendered in the output + /// The length written, or if the buffer is too small a negative that indicates the length required + internal static unsafe int llama_token_to_piece(LLamaVocabNative* vocabNative, LLamaToken llamaToken, Span buffer, int lstrip, bool special) { // Handle invalid tokens - if ((int)llamaToken < 0) + if ((int) llamaToken < 0) return 0; unsafe { fixed (byte* bufferPtr = buffer) { - return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special); + return llama_token_to_piece_native(vocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special); } } diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index d68274da..1897dd4b 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -635,62 +635,60 @@ public sealed class Vocabulary /// /// Map of each token in this vocabulary to its string representation /// - internal readonly IReadOnlyDictionary TokenToString; + public readonly IReadOnlyDictionary TokenToString; /// /// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc) /// - internal readonly IReadOnlyList EOGTokens; + internal readonly HashSet EOGTokens; /// /// Contains unique tokens that exist for inference control rather than text output /// - internal readonly IReadOnlyList ControlTokens; + internal readonly HashSet ControlTokens; - internal Vocabulary(SafeLlamaModelHandle model) + internal unsafe Vocabulary(SafeLlamaModelHandle model) { _model = model; - TokenToString = GetVocabCache(); // Cache the various properties that llama.cpp API exposes about the vocab - unsafe - { - var vocabNative = llama_model_get_vocab(_model); - Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative); - Type = LLamaVocabNative.llama_vocab_type(vocabNative); - BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative)); - EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative)); - Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative)); - Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative)); - SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative)); - InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative)); - InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative)); - InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative)); - InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative)); - InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative)); - InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative)); - EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative)); - DecoderStartToken = Normalize(llama_model_decoder_start_token(_model)); - ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative); - ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative); - - EOGTokens = TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token)).ToList(); - ControlTokens = TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token)).ToList(); - } - } - - private Dictionary GetVocabCache() - { + var vocabNative = llama_model_get_vocab(_model); + Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative); + Type = LLamaVocabNative.llama_vocab_type(vocabNative); + + BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative)); + EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative)); + EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative)); + Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative)); + SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative)); + Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative)); + + InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative)); + InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative)); + InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative)); + InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative)); + InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative)); + InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative)); + + DecoderStartToken = Normalize(llama_model_decoder_start_token(_model)); + ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative); + ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative); + + // Cache `TokenToString` for quick access var decoder = Encoding.UTF8.GetDecoder(); var (bytesArr, charsArr) = (new byte[1024], new char[1024]); - return Enumerable.Range(0, Count).ToDictionary( + TokenToString = Enumerable.Range(0, Count).ToDictionary( keySelector: i => (LLamaToken) i, elementSelector: i => { - decoder.Convert(bytesArr, 0, (int) _model.TokenToSpan(i, bytesArr), charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _); + var length = NativeApi.llama_token_to_piece(vocabNative, (LLamaToken) i, bytesArr, 0, true); + decoder.Convert(bytesArr, 0, length, charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _); return string.Join("", charsArr.Take(charsUsed)); } ); + + EOGTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token))); + ControlTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token))); } private static LLamaToken? Normalize(LLamaToken token) From 513c1e3dce9cf19192ab7b846ceecf202311d95a Mon Sep 17 00:00:00 2001 From: Lyrcaxis Date: Sun, 16 Mar 2025 18:54:57 +0200 Subject: [PATCH 5/5] Made cached EOG/Control tokens be `HashSet` for quicker lookup --- LLama/Native/LLamaToken.cs | 4 ++-- LLama/Native/SafeLlamaModelHandle.cs | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index 2b738ce8..c1b85f2f 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -99,7 +99,7 @@ public bool IsControl(SafeLlamaModelHandle model) /// public bool IsControl(SafeLlamaModelHandle.Vocabulary vocab) { - return vocab.ControlTokens.Contains(this); + return vocab.ControlTokens.Contains((int) this); } /// @@ -119,7 +119,7 @@ public bool IsEndOfGeneration(SafeLlamaModelHandle model) /// public bool IsEndOfGeneration(SafeLlamaModelHandle.Vocabulary vocab) { - return vocab.EOGTokens.Contains(this); + return vocab.EOGTokens.Contains((int) this); } /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 1897dd4b..63cc502e 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -640,12 +640,12 @@ public sealed class Vocabulary /// /// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc) /// - internal readonly HashSet EOGTokens; + internal readonly HashSet EOGTokens; /// /// Contains unique tokens that exist for inference control rather than text output /// - internal readonly HashSet ControlTokens; + internal readonly HashSet ControlTokens; internal unsafe Vocabulary(SafeLlamaModelHandle model) { @@ -687,8 +687,8 @@ internal unsafe Vocabulary(SafeLlamaModelHandle model) } ); - EOGTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token))); - ControlTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token))); + EOGTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token)).Select(x => (int) x)); + ControlTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token)).Select(x => (int) x)); } private static LLamaToken? Normalize(LLamaToken token)