diff --git a/bindings/python/benches/test_tiktoken.py b/bindings/python/benches/test_tiktoken.py index 3fdad5daf..08c7c9dad 100644 --- a/bindings/python/benches/test_tiktoken.py +++ b/bindings/python/benches/test_tiktoken.py @@ -63,9 +63,11 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document out = enc.encode("This is a test") hf_enc = Tokenizer.from_pretrained(model) + hf_enc.pre_tokenizer = None # apparently backtracking does not work without this out2 = hf_enc.encode("This is a test", add_special_tokens=False).ids - - assert out == out2, "sanity check" + print([hf_enc.decode([k]) for k in out2]) + print([hf_enc.decode([k]) for k in out]) + assert out == out2, f"sanity check {out} == {out2}, {hf_enc.decode(out)} == {hf_enc.decode(out2)}" start = time.perf_counter_ns() enc.encode_ordinary_batch(documents, num_threads=num_threads) @@ -74,9 +76,8 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9) print(f"tiktoken \t{readable_size} / s") - start = time.perf_counter_ns() - hf_enc.encode_batch_fast(documents) + hf_enc.encode_batch(documents) end = time.perf_counter_ns() readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9) print(f"huggingface \t{readable_size} / s") diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 2f4dba825..65da080ff 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -8,6 +8,7 @@ use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; use serde::{Deserialize, Serialize}; +use tk::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeBuilder}; use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE}; use tk::models::unigram::Unigram; use tk::models::wordlevel::WordLevel; @@ -39,6 +40,10 @@ impl PyModel { .into_pyobject(py)? .into_any() .into(), + ModelWrapper::BacktrackingBpe(_) => Py::new(py, (PyBacktrackingBpe {}, base))? + .into_pyobject(py)? + .into_any() + .into(), ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))? .into_pyobject(py)? .into_any() @@ -572,6 +577,147 @@ impl PyBPE { } } +#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BacktrackingBpe")] +struct PyBacktrackingBpe {} + +impl PyBacktrackingBpe { + fn with_builder( + mut builder: BacktrackingBpeBuilder, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult<(Self, PyModel)> { + if let Some(kwargs) = kwargs { + for (key, value) in kwargs { + let key: String = key.extract()?; + match key.as_ref() { + "unk_token" => { + if let Some(unk) = value.extract()? 
{ + builder = builder.unk_token(unk); + } + } + "fuse_unk" => builder = builder.fuse_unk(value.extract()?), + "byte_fallback" => builder = builder.byte_fallback(value.extract()?), + _ => println!("Ignored unknown kwarg option {}", key), + }; + } + } + + match builder.build() { + Err(e) => Err(exceptions::PyException::new_err(format!( + "Error while initializing BPE: {}", + e + ))), + Ok(bpe) => Ok((PyBacktrackingBpe {}, bpe.into())), + } + } +} + +#[pymethods] +impl PyBacktrackingBpe { + #[new] + #[pyo3( + signature = (vocab=None, merges=None, **kwargs), + text_signature = "(self, vocab=None, merges=None, dropout=None, unk_token=None)")] + fn new( + py: Python<'_>, + vocab: Option, + merges: Option, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult<(Self, PyModel)> { + if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) { + return Err(exceptions::PyValueError::new_err( + "`vocab` and `merges` must be both specified", + )); + } + + let mut builder = BacktrackingBpe::builder(); + if let (Some(vocab), Some(merges)) = (vocab, merges) { + match (vocab, merges) { + (PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => { + builder = builder.vocab_and_merges(vocab, merges); + } + _ => { + return Err(exceptions::PyValueError::new_err( + "`vocab` and `merges` must be both be from memory or both filenames", + )); + } + } + } + + PyBacktrackingBpe::with_builder(builder, kwargs) + } + + /// Read a :obj:`vocab.json` and a :obj:`merges.txt` files + /// + /// This method provides a way to read and parse the content of these files, + /// returning the relevant data structures. If you want to instantiate some BPE models + /// from memory, this method gives you the expected input from the standard files. + /// + /// Args: + /// vocab (:obj:`str`): + /// The path to a :obj:`vocab.json` file + /// + /// merges (:obj:`str`): + /// The path to a :obj:`merges.txt` file + /// + /// Returns: + /// A :obj:`Tuple` with the vocab and the merges: + /// The vocabulary and merges loaded into memory + #[staticmethod] + #[pyo3(text_signature = "(self, vocab, merges)")] + fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> { + BacktrackingBpe::read_file(vocab, merges).map_err(|e| { + exceptions::PyException::new_err(format!( + "Error while reading vocab & merges files: {}", + e + )) + }) + } + + /// Instantiate a BPE model from the given files. 
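+    /// (Note: on :class:`~tokenizers.models.BacktrackingBpe` the result is a backtracking
+    /// model; the files themselves are the same standard :obj:`vocab.json` and
+    /// :obj:`merges.txt` read by :meth:`~tokenizers.models.BPE.read_file`.)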
+ /// + /// This method is roughly equivalent to doing:: + /// + /// vocab, merges = BPE.read_file(vocab_filename, merges_filename) + /// bpe = BPE(vocab, merges) + /// + /// If you don't need to keep the :obj:`vocab, merges` values lying around, + /// this method is more optimized than manually calling + /// :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE` + /// + /// Args: + /// vocab (:obj:`str`): + /// The path to a :obj:`vocab.json` file + /// + /// merges (:obj:`str`): + /// The path to a :obj:`merges.txt` file + /// + /// Returns: + /// :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files + #[classmethod] + #[pyo3(signature = (vocab, merges, **kwargs))] + #[pyo3(text_signature = "(cls, vocab, merge, **kwargs)")] + fn from_file( + _cls: &Bound<'_, PyType>, + py: Python, + vocab: &str, + merges: &str, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult> { + let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| { + exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e)) + })?; + Py::new( + py, + PyBacktrackingBpe::new( + py, + Some(PyVocab::Vocab(vocab)), + Some(PyMerges::Merges(merges)), + kwargs, + )?, + ) + } +} + /// An implementation of the WordPiece algorithm /// /// Args: diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index dacb96298..01d786fc6 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -68,6 +68,8 @@ fancy-regex = { version = "0.14", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" +fnv = "1.0.7" +aneubeck-daachorse = "1.1.1" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs index 77af3bd63..a45e0caac 100644 --- a/tokenizers/benches/llama3.rs +++ b/tokenizers/benches/llama3.rs @@ -2,26 +2,51 @@ extern crate criterion; use criterion::{Criterion, Throughput}; +use itertools::Itertools; +use tokenizers::models::backtracking_bpe; +use tokenizers::PreTokenizerWrapper; use tokenizers::Tokenizer; pub fn llama3(c: &mut Criterion) { let data = std::fs::read_to_string("data/big.txt").unwrap(); let mut group = c.benchmark_group("llama3-encode"); group.throughput(Throughput::Bytes(data.bytes().len() as u64)); - group.bench_function("llama3-offsets", |b| { - let tokenizer = - Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap(); + + group.bench_function("llama3-backtracking", |b| { + let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let data: Vec<_> = data.lines().collect(); + let add_special_tokens = false; + b.iter(|| { + tokenizer + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) + .unwrap() + }) + }); + + group.bench_function("llama3-backtracking-no-pretok", |b| { + let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + tokenizer.with_pre_tokenizer(None::); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { tokenizer - .encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens) + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) .unwrap() }) }); - group.bench_function("llama3-nooffsets", |b| { - let tokenizer = - Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap(); + + group.bench_function("llama3-encode_batch_fast", |b| { + let tokenizer = 
Tokenizer::from_pretrained("gpt2", None).unwrap(); + let data: Vec<_> = data.lines().collect(); + let add_special_tokens = false; + b.iter(|| { + tokenizer + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) + .unwrap() + }) + }); + group.bench_function("llama3-encode_batch", |b| { + let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { @@ -30,13 +55,14 @@ pub fn llama3(c: &mut Criterion) { .unwrap() }) }); + group.finish(); } criterion_group! { - name = bert_benches; + name = llama; config = Criterion::default().sample_size(10); targets = llama3 } -criterion_main!(bert_benches); +criterion_main!(llama); diff --git a/tokenizers/src/models/backtracking_bpe/backtracking_state.rs b/tokenizers/src/models/backtracking_bpe/backtracking_state.rs new file mode 100644 index 000000000..98fe15381 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/backtracking_state.rs @@ -0,0 +1,49 @@ +use super::bitfield::BitField; + +/// This can be thought of as a lazy variation of the dynamic programming approach. +/// It only computes those states which have to be visited in order to compute the tokenization +/// for a given input text. +/// It keeps track of visited states in a bitfield and only remembers the tokenization +/// of the currently processed dynamic programming state. +/// +/// The biggest downside of this approach is that the search for the longest leftmost match (the firt token?) +/// has to be reset at every (backtracking) step which is still a net win in practice compared to other approaches. +#[derive(Clone, PartialEq)] +pub struct BacktrackState<'a> { + pub(crate) text: &'a [u8], + pub(crate) tokens: Vec, // len of the tezt / 3 + pub(crate) next_token: Option, // bpe.next_match(text) wich is longest_searcher.leftmost_find_iter(text)'s first match value + pub(crate) pos: usize, // current pos in the text? + pub(crate) bitfield: BitField, // keeps track of token boundaries? keeps track of all the valid tokenization positions and making the runtime linear in the input length. +} + +impl<'a> BacktrackState<'a> { + pub(crate) fn new(text: &'a [u8], next_token: Option) -> Self { + Self::with_capacity(text, next_token, text.len() / 3) + } + + pub(crate) fn with_capacity(text: &'a [u8], next_token: Option, cap: usize) -> Self { + Self { + text, + tokens: Vec::with_capacity(cap), + next_token, + pos: 0, + bitfield: BitField::new(text.len() + 1), + } + } + pub(crate) fn count(&self) -> usize { + self.tokens.len() + } + + pub(crate) fn pos(&self) -> usize { + self.pos + } + + pub(crate) fn last_token(&self) -> Option { + self.tokens.last().copied() + } + + pub(crate) fn into_tokens(self) -> Vec { + self.tokens + } +} diff --git a/tokenizers/src/models/backtracking_bpe/bitfield.rs b/tokenizers/src/models/backtracking_bpe/bitfield.rs new file mode 100644 index 000000000..832965931 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/bitfield.rs @@ -0,0 +1,57 @@ +/// Small helper to manage a bit field which supports predecessor and successor queries with a simple scan implementation. +/// This is sufficient for our use case, since two one bits will be at most 128 bits apart. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct BitField { + bitfield: Vec, +} + +impl BitField { + /// All bits are initialized to 1. 
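+    ///
+    /// Behaviour sketch (illustrative, matching the scan implementation below): with
+    /// `BitField::new(8)` every position is set, so `successor(0) == 0` and
+    /// `predecessor(7) == 7`; after `clear(3)`, `successor(3) == 4` and `predecessor(3) == 2`.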
+ pub(crate) fn new(bits: usize) -> Self { + Self { + bitfield: vec![u64::MAX; (bits + 63) / 64], + } + } + + pub(crate) fn is_set(&self, bit: usize) -> bool { + let (word, bit) = (bit / 64, bit % 64); + self.bitfield[word] & (1 << bit) != 0 + } + + pub(crate) fn clear(&mut self, bit: usize) { + let (word, bit) = (bit / 64, bit % 64); + self.bitfield[word] &= !(1 << bit); + } + + pub(crate) fn successor(&self, bit: usize) -> usize { + let (mut word_idx, bit_idx) = (bit / 64, bit % 64); + let word = self.bitfield[word_idx] >> bit_idx; + if word != 0 { + word.trailing_zeros() as usize + bit + } else { + loop { + word_idx += 1; + let word = self.bitfield[word_idx]; + if word != 0 { + break word.trailing_zeros() as usize + word_idx * 64; + } + } + } + } + + pub(crate) fn predecessor(&self, bit: usize) -> usize { + let (mut word_idx, bit_idx) = (bit / 64, bit % 64); + let word = self.bitfield[word_idx] << (63 - bit_idx); + if word != 0 { + bit - word.leading_zeros() as usize + } else { + loop { + word_idx -= 1; + let word = self.bitfield[word_idx]; + if word != 0 { + break word_idx * 64 + 63 - word.leading_zeros() as usize; + } + } + } + } +} diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs new file mode 100644 index 000000000..e84327b6d --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/mod.rs @@ -0,0 +1,6 @@ +mod backtracking_state; +mod bitfield; +mod model; +mod serialization; + +pub use model::*; diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs new file mode 100644 index 000000000..7a0a38110 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -0,0 +1,896 @@ +use super::bitfield::BitField; +use super::{super::bpe::trainer::BpeTrainer, super::bpe::Error, super::OrderedVocabIter}; +use crate::decoders::byte_level::{BYTES_CHAR, CHAR_BYTES}; +use crate::models::bpe::{MergeMap, Pair, BPE}; +use crate::models::find_hash_factor_for_dictionary; +use crate::tokenizer::{Model, Result, Token}; +use crate::utils::iter::ResultShunt; +use crate::{pre_tokenizers, Decoder}; +use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder}; +use fnv::{FnvHashMap, FnvHasher}; +use itertools::Itertools; +use regex_syntax::ast::print; +use serde_json::Value; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::hash::{Hash, Hasher}; +use std::ops::Range; +use std::{ + collections::HashMap, + fs::File, + io::prelude::*, + io::{BufRead, BufReader}, + path::{Path, PathBuf}, +}; +pub type Vocab = HashMap; +type VocabR = HashMap; +pub type Merges = Vec<(String, String)>; + +use super::backtracking_state::BacktrackState; + +struct Config { + files: Option<(String, String)>, + vocab: Vocab, + merges: Merges, + dropout: Option, + unk_token: Option, + fuse_unk: bool, + byte_fallback: bool, +} + +pub struct BacktrackingBpeBuilder { + config: Config, +} + +impl Default for BacktrackingBpeBuilder { + fn default() -> Self { + Self { + config: Config { + files: None, + vocab: HashMap::new(), + merges: vec![], + dropout: None, + unk_token: None, + fuse_unk: false, + byte_fallback: false, + }, + } + } +} + +/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. +#[derive(PartialEq, Clone)] +pub struct BacktrackingBpe { + /// All the decoded tokens concatenated into? used to build the aho corasick searchers + all_tokens: Vec, + /// Start index of each token in all_tokens. + /// The end is simply the next entry in this vector. 
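+    /// For example (illustrative): for the two tokens `"a"` and `"bc"`, `all_tokens` is
+    /// `b"abc"` and `token_starts` is `[0, 1, 3]`, so token `1` covers bytes `1..3`.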
+ token_starts: Vec, + /// Mapping from hash of token to token id. + bytes_hash_to_token: FnvHashMap, + /// The two tokens from which the token got merged. + /// If the token is an original one, than the two tokens point back to itself. + split_table: Vec<(u32, u32)>, + /// Mapping from a pair of tokens to a merged token if such a merged token exists. + pair_lookup: FnvHashMap<(u32, u32), u32>, + /// An aho corasick automaton to find the next longest token in a byte sequence. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + longest_searcher: DoubleArrayAhoCorasick, + /// An aho corasick automaton to find ALL tokens in a byte sequence. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + pub(crate) overlapping_searcher: DoubleArrayAhoCorasick, + /// An aho corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick, + /// Mapping from a token to the next longest prefix token. + /// This is in principle information represented by the AhoCorasick automaton. + /// But we don't have efficient access to it and therefore store it here again. + /// If there is none, then the value is set to u32::MAX. + next_prefix_match: Vec, + /// Hash factor used to prevent hash collisions. + hash_factor: u64, + pub vocab: Vocab, + pub vocab_r: VocabR, + unk_token: Option, + pub merges: MergeMap, +} + +use std::fmt; + +// Manually implement the Debug trait to exclude the `cache` field +impl fmt::Debug for BacktrackingBpe { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("BacktrackingBpe") + .field("vocab", &self.vocab) + .field("vocab_r", &self.vocab_r) + .field("split_table", &self.split_table) + .field("token_starts", &self.token_starts) + .field("pair_lookup", &self.pair_lookup) + .finish() + } +} + +impl BacktrackingBpeBuilder { + /// Constructs a new `BacktrackingBpeBuilder`. + pub fn new() -> Self { + Self::default() + } + + /// Set the input files. + #[must_use] + pub fn files(mut self, vocab: String, merges: String) -> Self { + self.config.files = Some((vocab, merges)); + self + } + + /// Set the vocab (token -> ID) and merges mappings. + #[must_use] + pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self { + self.config.vocab = vocab; + self.config.merges = merges; + self + } + + /// Use [dropout](https://arxiv.org/abs/1910.13267) with the model. + #[must_use] + pub fn dropout(mut self, dropout: f32) -> Self { + self.config.dropout = Some(dropout); + self + } + + /// Set the `UNK` token for the vocab. + #[must_use] + pub fn unk_token(mut self, unk_token: String) -> Self { + self.config.unk_token = Some(unk_token); + self + } + + /// Set the `fuse_unk` option. + #[must_use] + pub fn fuse_unk(mut self, fuse_unk: bool) -> Self { + self.config.fuse_unk = fuse_unk; + self + } + + /// Set the `byte_fallback` option. + #[must_use] + pub fn byte_fallback(mut self, byte_fallback: bool) -> Self { + self.config.byte_fallback = byte_fallback; + self + } + + /// Returns a `BacktrackingBpe` model that uses the `BacktrackingBpeBuilder`'s configuration. + pub fn build(mut self) -> Result { + // Validate dropout. 
+ if let Some(p) = self.config.dropout { + if !(0.0..=1.0).contains(&p) { + return Err(Error::InvalidDropout.into()); + } + } + + // Read files if necessary + if let Some((vocab, merges)) = self.config.files { + let (v, m) = BPE::read_file(&vocab, &merges)?; + self.config.vocab = v; + self.config.merges = m; + } + use crate::pre_tokenizers::byte_level::CHAR_BYTES; + let vocab_vec: Vec<_> = self.config + .vocab + .into_iter() + .sorted_unstable_by(|a, b| a.1.cmp(&b.1)) + .map(|(k, v)| { + k.chars().map(|b| CHAR_BYTES[&b] as u8).collect::>() + + }).collect(); + let hash = find_hash_factor_for_dictionary(vocab_vec.clone()); + let backtraching_bpe = BacktrackingBpe::from_dictionary( + vocab_vec.clone(), + Some(self.config.merges), + Some(hash), + ); + Ok(backtraching_bpe) + } +} + +impl Default for BacktrackingBpe { + fn default() -> Self { + Self::builder().build().unwrap() + } +} + +// A helper function to iterate over the tokens in a byte sequence +fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterator { + token_starts + .iter() + .tuple_windows() + .map(move |(start, end)| &all_tokens[*start as usize..*end as usize]) +} + +fn next_match(longest_searcher: &DoubleArrayAhoCorasick, text: &[u8]) -> Option { + longest_searcher + .leftmost_find_iter(text) + .map(|m| m.value()) + .next() +} + +fn is_valid_token_pair( + pair_lookup: &FnvHashMap<(u32, u32), u32>, + split_table: &[(u32, u32)], + mut token1: u32, + mut token2: u32, +) -> bool { + // Keep track of the maximum token which can still be chosen across the split point. + let mut limit = u32::MAX; + // println!("checking if {token1}, {token2} is a valid token_pair"); + loop { + // Check whether BPE would choose a different token pair across the split point. + // this is super super important + if let Some(combined) = pair_lookup.get(&(token1, token2)) { + if *combined < limit { + // println!("Done1"); + return false; + } + } + // Reverse the merge operation from BPE. + + // println!("{:?}", split_table); + if token1 > token2 { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + // println!("Done2"); + return true; + } + } + } else { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + // println!("Done3"); + return true; + } + } + } + } + +} + +fn token_range(token_starts: &[u32], token_id: u32) -> Range { + unsafe { + *token_starts.get_unchecked(token_id as usize) as usize + ..*token_starts.get_unchecked(token_id as usize + 1) as usize + } +} + +fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) -> &'a [u8] { + &all_tokens[token_range(token_starts, token_id)] +} + +fn hash_bytes(bytes: &[u8], factor: u64) -> u32 { + let mut hasher = FnvHasher::default(); + bytes.hash(&mut hasher); + // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash. + // To make them unique for the given tokens, we have to add unfortunately another multiplication. 
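+    // In short (descriptive sketch): FNV produces a 64-bit hash; multiplying by a
+    // dictionary-specific `factor` before keeping the top 32 bits yields the 32-bit key stored
+    // in `bytes_hash_to_token`, and `find_hash_factor_for_dictionary` searches for a factor
+    // that makes these keys collision-free for the given token set.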
+ ((hasher.finish().wrapping_mul(factor)) >> 32) as u32 +} +fn find_token_by_bytes( + all_tokens: &[u8], + token_starts: &[u32], + bytes_hash_to_token: &FnvHashMap, + bytes: &[u8], + hash_factor: u64, +) -> Option { + let hash = hash_bytes(bytes, hash_factor); + let token = *bytes_hash_to_token.get(&hash)?; + if token_bytes(all_tokens, token_starts, token) == bytes { + Some(token) + } else { + None + } +} + +/// Converts the merges strings (for example from `merges.txt` file) with the format +/// "{pair_a} {pair_b}" into the format expected by the BacktrackingBpe struct +pub(crate) fn convert_merges_to_hashmap>( + iter: I, + _vocab: &Vocab, +) -> Result { + let mut merges = vec![]; + + let lines = iter.filter(|l| !l.starts_with("#version")); + for (rank, line) in lines.enumerate() { + let parts = line.split(' ').collect::>(); + if parts.len() != 2 { + return Err(Error::BadMerges(rank + 1).into()); + } + + merges.push((parts[0].to_string(), parts[1].to_string())); + } + + Ok(merges) +} + +impl BacktrackingBpe { + /// Initialize a `BacktrackingBpeBuilder`. + pub fn builder() -> BacktrackingBpeBuilder { + BacktrackingBpeBuilder::new() + } + + /// Create a new BacktrackingBpe model with the given vocab and merges. + pub fn new(vocab: Vocab, merges: Merges) -> Self { + Self::builder() + .vocab_and_merges(vocab, merges) + .build() + .unwrap() + } + + fn bitfield_into_tokens(&self, bytes: &[u8], bitfield: BitField, count: usize) -> Vec { + let mut encoded = Vec::with_capacity(count); + let mut start = 0; + while start < bytes.len() { + let end = bitfield.successor(start + 1); + // println!("bitfield's successor {:?}", &bytes[start..end]); + let token = self + .find_token_by_bytes(&bytes[start..end]) + .expect(&format!( + "Could not convert bytes to tokens for bytes: [{:?}]", + bytes.into_iter().map(|b| BYTES_CHAR[b]).join("") + )); + encoded.push(token); + start = end; + } + encoded + } + + fn encode_into_bitfield(&self, bytes: &[u8]) -> (BitField, usize) { + // Reserve for every byte a bit in the bitfield. + let mut bitfield = BitField::new(bytes.len() + 1); + let mut heap = BinaryHeap::with_capacity(bytes.len() * 2); + heap.extend((0..bytes.len().saturating_sub(1)).filter_map(|i| { + self.find_token_by_bytes(&bytes[i..i + 2]) + .map(|e| Reverse((e, i as u32))) + })); + let mut count = bytes.len(); + while let Some(Reverse((token, start))) = heap.pop() { + let start = start as usize; + if !bitfield.is_set(start) { + continue; + } + let mid = bitfield.successor(start + 1); + if mid >= bytes.len() { + continue; + } + let end = bitfield.successor(mid + 1); + if self.token_len(token) != end - start { + continue; + } + bitfield.clear(mid); + count -= 1; + if end < bytes.len() { + let new_end = bitfield.successor(end + 1); + if let Some(e) = self.find_token_by_bytes(&bytes[start..new_end]) { + heap.push(Reverse((e, start as u32))); + } + } + if start > 0 { + let new_start = bitfield.predecessor(start - 1); + if let Some(e) = self.find_token_by_bytes(&bytes[new_start..end]) { + heap.push(Reverse((e, new_start as u32))); + } + } + } + (bitfield, count) + } + + pub fn encode_via_bitfield(&self, text: &[u8]) -> Vec { + let (bitfield, count) = self.encode_into_bitfield(text); + self.bitfield_into_tokens(text, bitfield, count) + } + + /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens. + /// A suitable hash factor may be necessary to prevent hash collisions, which can be + /// found using [`find_hash_factor_for_dictionary`]. 
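+    ///
+    /// Hedged usage sketch (the three-token dictionary is illustrative only):
+    ///
+    /// ```ignore
+    /// let tokens = ["a", "b", "ab"].map(|t| t.as_bytes().to_vec());
+    /// let hash = find_hash_factor_for_dictionary(tokens.clone());
+    /// let bpe = BacktrackingBpe::from_dictionary(tokens, None, Some(hash));
+    /// ```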
+ /// + /// The recommended approach is to store the serialized value and reuse that, + /// to prevent repeating the cost of computing the hash factor and encoding. + pub fn from_dictionary( + tokens: impl IntoIterator>, + merges: Option, + hash_factor: Option, + ) -> Self { + let hash_factor = hash_factor + .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero")) + .unwrap_or(1); + let mut all_tokens = Vec::new(); + let mut all_tokens_rev = Vec::new(); + let mut token_starts = vec![0]; // The begin byte index of each token in all_tokens. + let mut bytes_hash_to_token = FnvHashMap::default(); + let mut merge_map: HashMap = HashMap::new(); + for (i, token) in tokens.into_iter().enumerate() { + use pre_tokenizers::byte_level::ByteLevel; + info!( + "token byte: {:?}, {i}", + ByteLevel::default() + .decode_chain(unsafe { vec![String::from_utf8_unchecked(token.clone())] }) + .unwrap() + ); + bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); + all_tokens_rev.extend(token.iter().copied().rev()); + all_tokens.extend(token); + token_starts.push(all_tokens.len() as u32); + } + assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len(), "Some tokens are not unique under the hash function!"); // TODO maybe this check is needed? + let longest_searcher = DoubleArrayAhoCorasickBuilder::new() + .match_kind(aneubeck_daachorse::MatchKind::LeftmostLongest) + .build(token_iter(&all_tokens, &token_starts)) + .expect("failed to build AhoCorasick"); + + let overlapping_searcher = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens, &token_starts)).expect(""); + let overlapping_searcher_rev = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens_rev, &token_starts)) + .expect(""); + + let next_prefix_match: Vec<_> = token_iter(&all_tokens, &token_starts) + .map(|token| { + next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX) + }) + .collect(); + + use pre_tokenizers::byte_level::BYTES_CHAR; + let vocab: HashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, bytes)| { + ( + bytes + .iter() + .map(|b| BYTES_CHAR[b]) + .collect::(), + id as u32, + ) + }) + .collect(); + + let vocab_r: HashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, bytes)| { + ( + id as u32, + bytes + .iter() + .map(| b| BYTES_CHAR[b]) + .collect::(), + ) + }) + .collect(); + + let mut split_table = vec![]; + let mut pair_lookup = FnvHashMap::default(); + + // // First option, use the input merge table. + // if let Some(ref merges) = merges { + // for (index, pair) in merges.into_iter().enumerate() { + // let token1 = &pair.0.clone(); + // let token2 = &pair.1.clone(); + // // TODO something is weird here + // if token1.len() ==1{ + // split_table.push((vocab[token1], vocab[token1])); + // } + // if token2.len() == 1 { + // split_table.push((vocab[token2], vocab[token2])); + // } + // let id1 = vocab[token1]; + // let id2 = vocab[token2]; + // let new_token = format!("{}{}", token1, &token2); + // let new_id = vocab + // .get(&new_token) + // .ok_or(Error::MergeTokenOutOfVocabulary(new_token)); + // if let Ok(id) = new_id { + // pair_lookup.insert((id1, id2), *id); + // split_table.push((id1, id2)); + // merge_map.insert(Pair::from((id1, id2)), (index as u32, *id)); + // } else { + // println!("Token not added?"); + // } + + // // TODO wrong + // } + // split_table.push((merges.len() as u32, merges.len() as u32)); + // } + // Second option, reverse engineer the merge/split table from the vocabulary. 
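+        // Rough reading of the loop below: for every token, walk its successively shorter
+        // prefix tokens via `next_prefix_match`; if the remaining suffix is itself a token,
+        // both ids are smaller than the current id, and the pair passes `is_valid_token_pair`,
+        // record it in `pair_lookup`/`split_table` (and `merge_map`) and stop. Tokens with no
+        // such split point back to themselves.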
+ { + for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { + let mut id1 = next_prefix_match[id]; + while id1 != u32::MAX { + let rest = &token[token_range(&token_starts, id1).len()..]; + if let Some(id2) = find_token_by_bytes( + &all_tokens, + &token_starts, + &bytes_hash_to_token, + rest, + hash_factor, + ) { + if id1 < id as u32 + && id2 < id as u32 + && is_valid_token_pair(&pair_lookup, &split_table, id1, id2) + { + pair_lookup.insert((id1, id2), id as u32); + split_table.push((id1, id2)); + merge_map.insert(Pair::from((id1, id2)), (id as u32, id as u32)); + break; + } + } + id1 = next_prefix_match[id1 as usize]; + } + if id1 == u32::MAX { + split_table.push((id as u32, id as u32)); + } + } + }; + + let bpe = Self { + all_tokens, + token_starts, + bytes_hash_to_token, + overlapping_searcher, + overlapping_searcher_rev, + longest_searcher, + next_prefix_match, + pair_lookup, + split_table, + hash_factor, + unk_token: None, + vocab, + vocab_r, + merges: merge_map, + }; + + // A health checkup + for token_id in 0..bpe.num_tokens() as u32 { + let bytes = bpe.token_bytes(token_id); + let strs = bytes.iter().map(|b| char::from(*b)).collect::>(); + println!("Encoding {bytes:?} into bitfield"); + let tokens = bpe.encode_via_bitfield(bytes); + assert_eq!( + tokens, + vec![token_id], + "token {token_id} with bytes {bytes:?} (tokens {strs:?} encodes to {tokens:?} instead of to itself" + ); + } + // println!("{:#?}", bpe); + bpe + } + + /// Initialize a BacktrackingBpeBuilder model from vocab and merges files + pub fn from_file(vocab: &str, merges: &str) -> BacktrackingBpeBuilder { + Self::builder().files(vocab.to_owned(), merges.to_owned()) + } + + /// Read the given files to extract the vocab and merges + pub fn read_file(vocab: &str, merges: &str) -> Result<(Vocab, Merges)> { + // Read vocab.json + let vocab_file = File::open(vocab)?; + let mut vocab_file = BufReader::new(vocab_file); + + let mut buffer = String::new(); + vocab_file.read_to_string(&mut buffer)?; + let json: Value = serde_json::from_str(&buffer)?; + let mut vocab = HashMap::new(); + match json { + Value::Object(m) => { + for (token, id) in m { + if let Value::Number(id) = id { + let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32; + vocab.insert(token, id); + } + } + } + _ => return Err(Box::new(Error::BadVocabulary)), + }; + + // Read merges file + let merge_file = File::open(merges)?; + let merge_file = BufReader::new(merge_file); + let merges = ResultShunt::process(merge_file.lines(), |iter| { + convert_merges_to_hashmap(iter, &vocab) + })??; // TODO correctly process to fill the split and pair lookup + + Ok((vocab, merges)) + } + + /// Return the number of tokens in this BPE dictionary. + pub fn num_tokens(&self) -> usize { + self.token_starts.len() - 1 + } + + /// Converts a token id into its corresponding token bytes. + /// Panics if the token_id is not within the valid 0..num_tokens() range! + pub fn token_bytes(&self, token_id: u32) -> &[u8] { + token_bytes(&self.all_tokens, &self.token_starts, token_id) + } + + pub(crate) fn is_valid_token_pair(&self, token1: u32, token2: u32) -> bool { + is_valid_token_pair(&self.pair_lookup, &self.split_table, token1, token2) + } + + /// Returns the length of the decoded byte slice of a token. + pub fn token_len(&self, token_id: u32) -> usize { + token_range(&self.token_starts, token_id).len() + } + + /// Returns the first longest match in the provided text. 
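+    /// (That is, the value of the first leftmost-longest match reported by
+    /// `longest_searcher`, or `None` if no token matches anywhere in `text`.)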
+ pub(crate) fn next_match(&self, text: &[u8]) -> Option { + next_match(&self.longest_searcher, text) + } + + /// Returns the next token which shares the longest prefix with the specified token. + pub(crate) fn next_prefix(&self, token_id: u32) -> Option { + let prefix = self.next_prefix_match[token_id as usize]; + if prefix == u32::MAX { + None + } else { + Some(prefix) + } + } + + fn find_token_by_bytes(&self, bytes: &[u8]) -> Option { + find_token_by_bytes( + &self.all_tokens, + &self.token_starts, + &self.bytes_hash_to_token, + bytes, + self.hash_factor, + ) + } + + /// Decode a sequence of tokens back to its original byte sequence. + /// Note: we don't return here a str, since not every token sequence corresponds to a valid + /// utf8 sequence. + pub fn decode_tokens(&self, tokens: &[u32]) -> Vec { + let mut text = vec![]; + for token in tokens { + text.extend(self.token_bytes(*token)); + } + text + } + + /// Computes for every prefix of the input text a corresponding last token. + pub(crate) fn encode_all_prefixes(&self, text: &[u8]) -> Vec { + let mut last_token = Vec::with_capacity(text.len()); + let mut state = self.overlapping_searcher.start_state(); + for (pos, c) in text.iter().enumerate() { + let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c); + state = s; + for m in iter { + let new_token = m.value(); + let new_range = m.start()..m.end(); + assert_eq!(new_range.end, last_token.len() + 1); + if new_range.start == 0 { + last_token.push(new_token); + break; + } else { + let prev_token = unsafe { *last_token.get_unchecked(new_range.start - 1) }; + if self.is_valid_token_pair(prev_token, new_token) { + last_token.push(new_token); + break; + } + // println!("Finished encoding prefix") + } + } + } + last_token + } + + /// Counts the number tokens produced when encoding the text. + pub fn count(&mut self, text: &[u8]) -> usize { + let mut enc = BacktrackState::new(text, None); + while self.step(&mut enc).is_some() {} + enc.count() + } + + pub fn encode_via_table(&self, text: &[u8]) -> Vec { + let last_token = self.encode_all_prefixes(text); + let mut encoded = Vec::with_capacity(text.len() / 3); + let mut pos = text.len(); + while pos > 0 { + let token = last_token[pos - 1]; + encoded.push(token); + pos -= self.token_len(token); + } + encoded.reverse(); + encoded + } + + pub fn encode_via_backtracking(&self, text: &[u8]) -> Vec { + let next_token = self.next_match(text); + let mut enc = BacktrackState::new(text, next_token); + while self.step(&mut enc).is_some() {} + enc.into_tokens() + } + + pub fn get_vocab(&self) -> Vocab { + self.vocab.clone() + } + + pub fn get_unk_token(&self) -> &Option { + &self.unk_token + } + + pub fn step(&self, backtrack_state: &mut BacktrackState) -> Option { + let mut token = backtrack_state.next_token?; + let last = backtrack_state.tokens.last().copied(); + loop { + // println!("in step, token: {last:?}, {token}"); + let token_len = self.token_len(token); + let end_pos = backtrack_state.pos + token_len; + if backtrack_state.bitfield.is_set(end_pos) + && last + .map(|last_token| self.is_valid_token_pair(last_token, token)) + .unwrap_or(true) + { + backtrack_state.tokens.push(token); + backtrack_state.pos = end_pos; + // In principle, we could in some cases reuse the leftmost longest match iterator. + // Especially when it has to look ahead, this could save scanning the input multiple times. + // But on average this seems to be slower due to the overhead of storing the iterator as part of the struct. 
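+                // Descriptive summary of this branch: the candidate ends on a still-set
+                // bitfield position and forms a valid pair with the previous token, so commit
+                // it and look for the next match after `end_pos`. The branches below instead
+                // retry with the next-longest prefix token, or pop the previous token (the
+                // actual backtracking) when no shorter candidate exists.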
+ backtrack_state.next_token = self.next_match(&backtrack_state.text[end_pos..]); + break; + } else if let Some(shorter) = self.next_prefix(token) { + token = shorter; + } else { + // Clearing the bitfield when we pop tokens saves a little bit of work... + backtrack_state.bitfield.clear(backtrack_state.pos); + backtrack_state.tokens.pop(); + backtrack_state.pos -= last.map(|t| self.token_len(t)).unwrap_or(0); + backtrack_state.next_token = last; + break; + } + } + // println!("finished step, token: {last:?}, {token}"); + + backtrack_state.next_token + } + + fn word_to_tokens<'a, 'b: 'a>( + &'a self, + word: &'b Vec, + ) -> impl Iterator + 'a { + word.into_iter() + .map(move |id| Token::new(*id, self.vocab_r[&id].clone(), (0usize, 0usize))) + // TODO offsets should be easy to integrate as well! + } +} +impl Model for BacktrackingBpe { + type Trainer = BpeTrainer; + + fn get_vocab(&self) -> HashMap { + self.vocab.clone() + } + + fn get_vocab_size(&self) -> usize { + self.vocab.len() + } + + fn tokenize(&self, sequence: &str) -> Result> { + if sequence.is_empty() { + return Ok(vec![]); + } + let byte_text = sequence.as_bytes(); + let word = self.encode_via_backtracking(byte_text); + Ok(self.word_to_tokens(&word).collect()) + } + + fn token_to_id(&self, token: &str) -> Option { + self.vocab.get(token).copied() + } + + fn id_to_token(&self, id: u32) -> Option { + Some(self.vocab_r[&id].clone()) + } + + fn save(&self, folder: &Path, name: Option<&str>) -> Result> { + let vocab_file_name = match name { + Some(name) => format!("{name}-vocab.json"), + None => "vocab.json".to_string(), + }; + + // Write vocab.json + let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())] + .iter() + .collect(); + let mut vocab_file = File::create(&vocab_path)?; + let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r); + let serialized = serde_json::to_string(&order_vocab_iter)?; + vocab_file.write_all(serialized.as_bytes())?; + Ok(vec![vocab_path]) + // Ok(vec![vocab_path, merges_path]) + } + + fn get_trainer(&self) -> BpeTrainer { + BpeTrainer::default() + } +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn my_example() { + let tokens = [ + "a", "b", "c", // 1 character each + "aac", "ac", "cc", "cca", "aacc", "aaccca", "acca", "acc", "aa", "aaa", + "aaaa", // 2 characters each + ]; + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); + // bpe.encode_via_backtracking(b"baacca"); + let tokens = bpe.tokenize("aaaacc").unwrap(); + println!("{:?}", bpe.tokenize("aaaacc")); + assert_eq!( + tokens, + vec![ + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 10, + value: String::from("acc"), + offsets: (0, 0) + } + ] + ); + println!("{:?}", bpe.tokenize("baaaaccca")); + let tokens = bpe.tokenize("baaaaccca").unwrap(); + assert_eq!( + tokens, + vec![ + Token { + id: 1, + value: String::from("b"), + offsets: (0, 0) + }, + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 4, + value: String::from("ac"), + offsets: (0, 0) + }, + Token { + id: 6, + value: String::from("cca"), + offsets: (0, 0) + } + ] + ); + bpe.encode_via_backtracking(b"baaaaccca"); + let tokens = [ + "a", "b", "c", // 1 character each + "acca", "cc", "ac", "aac", "cca", + ]; + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); + bpe.encode_via_backtracking(b"baacca"); + } +} diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs 
b/tokenizers/src/models/backtracking_bpe/serialization.rs new file mode 100644 index 000000000..eae5a65d0 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -0,0 +1,264 @@ +use super::{ + super::bpe::Pair, super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, + BacktrackingBpeBuilder, +}; +use serde::{ + de::{Error, MapAccess, Visitor}, + ser::SerializeStruct, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::collections::HashMap; + +impl Serialize for BacktrackingBpe { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut model = serializer.serialize_struct("BacktrackingBpe", 8)?; + + // Start by small fields + model.serialize_field("type", "BacktrackingBpe")?; + + // Then the large ones + let mut merges: Vec<(&Pair, &u32)> = self + .merges + .iter() + .map(|(pair, (rank, _))| (pair, rank)) + .collect(); + merges.sort_unstable_by_key(|k| *k.1); + let merges = merges + .into_iter() + .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone())) + .collect::>(); + let ordered_vocab = OrderedVocabIter::new(&self.vocab_r); + model.serialize_field("vocab", &ordered_vocab)?; + model.serialize_field("merges", &merges)?; + + model.end() + } +} + +impl<'de> Deserialize<'de> for BacktrackingBpe { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_struct( + "BacktrackingBpe", + &["type", "dropout", "unk_token", "vocab", "merges"], + BacktrackingBpeVisitor, + ) + } +} + +struct BacktrackingBpeVisitor; +impl<'de> Visitor<'de> for BacktrackingBpeVisitor { + type Value = BacktrackingBpe; + + fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(fmt, "struct BacktrackingBpe to be the type") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: MapAccess<'de>, + { + let mut builder = BacktrackingBpeBuilder::new(); + let mut vocab: Option> = None; + + #[derive(Debug, Deserialize)] + #[serde(untagged)] + enum MergeType { + Tuple(Vec<(String, String)>), + Legacy(Vec), + } + let mut merges: Option = None; + while let Some(key) = map.next_key::()? { + match key.as_ref() { + "dropout" => { + if let Some(dropout) = map.next_value()? { + builder = builder.dropout(dropout); + } + } + "unk_token" => { + if let Some(unk) = map.next_value()? { + builder = builder.unk_token(unk); + } + } + "vocab" => vocab = Some(map.next_value()?), + "merges" => merges = Some(map.next_value()?), + "type" => match map.next_value()? 
{ + "BacktrackingBpe" => {} + "BPE" => { + info!("Type is BPE but initializing a backtracking BPE") + } + u => { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(u), + &"BacktrackingBpe should have been found", + )) + } + }, + field => { + info!("Ignoring unused field {:?}", field); // TODO make it into a logger + // Ensure the value is consumed to maintain valid deserialization + let _ = map.next_value::()?; + } + } + } + if let (Some(vocab), Some(merges)) = (vocab, merges) { + let merges = match merges { + MergeType::Tuple(merges) => merges, + MergeType::Legacy(merges) => convert_merges_to_hashmap(merges.into_iter(), &vocab) + .map_err(|e| Error::custom("Error in convert merges to hashmap"))?, + }; + builder = builder.vocab_and_merges(vocab, merges); + let model = builder.build().map_err(|e| { + Error::custom(format!("Error building the backtraciing BPE {:?}", e)) + })?; + println!("deserialized the model"); + Ok(model) + } else { + Err(Error::custom("Missing vocab/merges")) + } + } +} + +#[cfg(test)] +mod test { + use std::process::exit; + + use super::*; + use crate::models::bpe::Vocab; + use crate::tokenizer::Tokenizer; + + #[test] + fn test_serialization() { + let bpe_string = r#"{ + "type": "BPE", + "dropout": null, + "unk_token": "", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "a": 0, + "b": 1, + "ab": 2, + "aba": 3, + "abb": 4, + "bb":5, + "abbb":6 + }, + "merges": [ + ["a", "b"], + ["ab", "a"], + ["ab", "b"], + ["b", "b"], + ["ab", "bb"] + ] + }"#; + // [(0, 1), (2, 0), (2, 1), (2, 5), (1, 1), (5, 5), (0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)] + // above is the expected split table. In string equivalent: + // ["a,b", "c,a", "c,b", "c,bb", "b,b", "bb, bb", "a,a", "b,b". "ab,ab", "abb, abb", "bb, bb"] + let reconstructed: Result = + serde_json::from_str(&bpe_string); + match reconstructed { + Ok(reconstructed) => { + println!("Good. Now doing backtracking:"); + println!("{:?}", reconstructed.encode_via_backtracking(b"aab")); + } + Err(err) => { + println!("Error deserializing: {:?}", err); + } + } + println!("End of my example"); + let vocab: Vocab = [ + ("a".into(), 0), + ("b".into(), 1), + ("ab".into(), 2), + ("aba".into(), 3), + ("abb".into(), 4), + ] + .iter() + .cloned() + .collect(); + let bpe = BacktrackingBpeBuilder::default() + .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())]) + .unk_token("".to_string()) + .build() + .unwrap(); + + println!( + "First encoding: {:?}", + bpe.encode_via_backtracking(b"aabbab") + ); + + let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"","fuse_unk":false,"byte_fallback":false,"vocab":{"a":1,"b":2,"ab":3},"merges":["a b"]}"#; + let legacy = serde_json::from_str(legacy); + match legacy { + Ok(_) => { + println!("Good"); + assert_eq!(bpe, legacy.unwrap()); + } + Err(err) => { + println!("Error: {:?}", err); + } + } + + let data = serde_json::to_string(&bpe).unwrap(); + assert_eq!( + data, + r#"{"type":"BPE","vocab":{"ab":0,"a":1,"b":2},"merges":[["a","b"]]}"# + ); + let reconstructed = serde_json::from_str(&data).unwrap(); + assert_eq!(bpe, reconstructed); // TODO failing for now! 
+ + // With a space in the token + let vocab: Vocab = [ + ("".into(), 0), + ("a".into(), 1), + ("b c d".into(), 2), + ("ab c d".into(), 3), + ] + .iter() + .cloned() + .collect(); + let bpe = BacktrackingBpeBuilder::default() + .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())]) + .unk_token("".to_string()) + .build() + .unwrap(); + let data = serde_json::to_string(&bpe).unwrap(); + assert_eq!( + data, + r#"{"type":"BacktrackingBpe","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"# + ); + let reconstructed = serde_json::from_str(&data).unwrap(); + assert_eq!(bpe, reconstructed); + } + + #[cfg(feature = "http")] + #[test] + fn test_from_pretrained() { + let bpe = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let bpe_string = serde_json::to_string(&bpe.get_model()).unwrap(); + let reconstructed: Result = + serde_json::from_str(&bpe_string); + match reconstructed { + Ok(reconstructed) => { + println!("Good from_pretrained reconstruction"); + println!( + "{:?}", + reconstructed.encode_via_backtracking(b"Hello, my name is") + ); + // assert_eq!(bpe, reconstructed); + } + Err(err) => { + println!("Error deserializing: {:?}", err); + } + } + } +} diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index f0d40b2df..97d8e12ad 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -4,9 +4,9 @@ use std::{iter, mem}; mod model; mod serialization; pub mod trainer; -mod word; +pub mod word; -type Pair = (u32, u32); +pub(crate) type Pair = (u32, u32); /// Errors that can be encountered while using or constructing a `BPE` model. 
#[derive(thiserror::Error, Debug)] @@ -79,4 +79,4 @@ where // Re-export pub use model::*; pub use trainer::*; -use word::*; +pub use word::*; diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 217c37e90..c2f9abf9a 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,4 +1,5 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; + use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; @@ -13,7 +14,7 @@ use std::{ }; pub type Vocab = HashMap; -type VocabR = HashMap; +pub type VocabR = HashMap; pub type MergeMap = HashMap; pub type Merges = Vec<(String, String)>; diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index a1a0aba76..ec449de5e 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -1,6 +1,7 @@ #![allow(clippy::map_entry)] use super::{Pair, WithFirstLastIterator, Word, BPE}; +use crate::models::Bpe; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressStyle}; @@ -432,7 +433,7 @@ impl BpeTrainer { pub fn do_train( &self, word_counts: &HashMap, - model: &mut BPE, + model: &mut Bpe, // add a generic BPE ) -> Result> { let mut word_to_id: HashMap = HashMap::with_capacity(self.vocab_size); let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); @@ -601,29 +602,34 @@ impl BpeTrainer { self.finalize_progress(&progress, merges.len()); // Transfer new vocab & options to model - model.vocab = word_to_id; - model.vocab_r = model - .vocab + let vocabulary = word_to_id.clone(); + let vocab_reversed: HashMap = word_to_id .iter() .map(|(key, val)| (*val, key.to_owned())) .collect(); - model.merges = merges + let merges: HashMap = merges .into_iter() .enumerate() .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id))) .collect(); - if let Some(prefix) = &self.continuing_subword_prefix { - model.continuing_subword_prefix = Some(prefix.to_owned()); + let continuing_subword_prefix = if let Some(prefix) = &self.continuing_subword_prefix { + Some(prefix.to_owned()) } else { - model.continuing_subword_prefix = None; - } - if let Some(suffix) = &self.end_of_word_suffix { - model.end_of_word_suffix = Some(suffix.to_owned()); + None + }; + let end_of_word_suffix = if let Some(suffix) = &self.end_of_word_suffix { + Some(suffix.to_owned()) } else { - model.end_of_word_suffix = None; - } - + None + }; + model.with( + vocabulary, + vocab_reversed, + merges, + end_of_word_suffix, + continuing_subword_prefix, + ); Ok(self.special_tokens.clone()) } } @@ -633,7 +639,7 @@ impl Trainer for BpeTrainer { /// Train a BPE model fn train(&self, model: &mut BPE) -> Result> { - self.do_train(&self.words, model) + self.do_train(&self.words, &mut Bpe::OriginalBpe(model.to_owned())) } /// Whether we should show progress @@ -675,6 +681,8 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { + use crate::models::Bpe; + use super::{BpeTrainer, Pair, BPE}; use std::collections::HashMap; @@ -700,7 +708,7 @@ mod tests { .show_progress(false) .min_frequency(2) .build(); - let mut model = BPE::default(); + let mut model =Bpe::OriginalBpe(BPE::default()); trainer.do_train(&word_counts, &mut model).unwrap(); // Vocab should contain all of the characters from the `word_counts` mapping @@ -735,7 +743,7 @@ mod tests { .iter() .cloned() .collect(); - 
assert_eq!(model.vocab, expected_vocab); + assert_eq!(model.get_vocab(), expected_vocab); // The keys in `merges` are pairs of symbols, the values are tuples of (rank, id), // where 'rank' determines the order in which this merge will be applied during @@ -749,7 +757,7 @@ mod tests { .iter() .cloned() .collect(); - assert_eq!(model.merges, expected_merges); + assert_eq!(model.get_merges(), expected_merges); } #[test] fn bpe_test_max_token_length_16() { @@ -781,7 +789,8 @@ mod tests { .show_progress(false) .min_frequency(0) .build(); - let mut model = BPE::default(); + let mut model = Bpe::OriginalBpe(BPE::default()); + trainer.do_train(&long_word_counts, &mut model).unwrap(); let vocab = model.get_vocab(); for token in vocab.keys() { @@ -821,7 +830,7 @@ mod tests { .show_progress(false) .min_frequency(0) .build(); - let mut model = BPE::default(); + let mut model = Bpe::OriginalBpe(BPE::default()); trainer.do_train(&long_word_counts, &mut model).unwrap(); let trained_vocab: HashMap = model.get_vocab(); let expected_vocab: HashMap = [ diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 93b3d9c37..8c951a679 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -53,7 +53,7 @@ impl Symbol { } #[derive(Clone, Default)] -pub(super) struct Word { +pub struct Word { symbols: Vec, } impl std::fmt::Debug for Word { @@ -74,7 +74,7 @@ impl std::fmt::Debug for Word { } impl Word { - pub(super) fn new() -> Self { + pub(crate) fn new() -> Self { Word { symbols: vec![] } } @@ -84,7 +84,7 @@ impl Word { } } - pub(super) fn add(&mut self, c: u32, byte_len: usize) { + pub(crate) fn add(&mut self, c: u32, byte_len: usize) { let (prev, next) = { let len = self.symbols.len() as isize; if let Some(last) = self.symbols.last_mut() { @@ -103,7 +103,7 @@ impl Word { }); } - pub(super) fn merge( + pub(crate) fn merge( &mut self, c1: u32, c2: u32, @@ -251,7 +251,7 @@ impl Word { self.symbols.retain(|s| s.len != 0); } - pub(super) fn get_chars(&self) -> Vec { + pub(crate) fn get_chars(&self) -> Vec { self.symbols.iter().map(|s| s.c).collect() } diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 3a3a91adc..9594faf5b 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -1,15 +1,20 @@ //! Popular tokenizer models. 
+pub mod backtracking_bpe; pub mod bpe; pub mod unigram; pub mod wordlevel; pub mod wordpiece; use std::collections::HashMap; +use std::hash::{Hash, Hasher}; use std::path::{Path, PathBuf}; +use bpe::{Merges, Vocab, VocabR}; +use itertools::Itertools; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::models::backtracking_bpe::BacktrackingBpe; use crate::models::bpe::{BpeTrainer, BPE}; use crate::models::unigram::{Unigram, UnigramTrainer}; use crate::models::wordlevel::{WordLevel, WordLevelTrainer}; @@ -48,7 +53,6 @@ impl Serialize for OrderedVocabIter<'_> { } else { serializer.collect_map(std::iter::empty::<(&str, u32)>()) }; - if !holes.is_empty() { warn!("The OrderedVocab you are attempting to save contains holes for indices {:?}, your vocabulary could be corrupted !", holes); println!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !"); @@ -61,6 +65,7 @@ impl Serialize for OrderedVocabIter<'_> { #[serde(untagged)] pub enum ModelWrapper { BPE(BPE), + BacktrackingBpe(BacktrackingBpe), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility // with the versions not including the "type"), since WordLevel is a subset of WordPiece WordPiece(WordPiece), @@ -68,6 +73,84 @@ pub enum ModelWrapper { Unigram(Unigram), } +pub enum Bpe { + OriginalBpe(BPE), + BacktrackingBpe(BacktrackingBpe), +} + +impl Bpe { + fn get_vocab(& self) -> Vocab { + match self { + Bpe::OriginalBpe(model) => model.get_vocab(), + Bpe::BacktrackingBpe(model) => model.get_vocab() + } + } + + fn get_merges(&self) -> HashMap<(u32, u32), (u32, u32)> { + match self { + Bpe::OriginalBpe(model) => model.merges.clone(), + Bpe::BacktrackingBpe(model) => model.merges.clone() + } + } + + fn with( + &mut self, + vocab: Vocab, + vocab_r: VocabR, + merge_map: HashMap<(u32, u32), (u32, u32)>, + end_of_word_suffix: Option, + continous_subword_prefix: Option, + ) -> &mut Self { + match self { + Bpe::OriginalBpe(model) => { + model.vocab = vocab; + model.vocab_r = vocab_r; + model.merges = merge_map; + model.end_of_word_suffix = end_of_word_suffix; + model.continuing_subword_prefix = continous_subword_prefix; + } + Bpe::BacktrackingBpe(model) => { + model.vocab = vocab; + model.vocab_r = vocab_r; + model.merges = merge_map; + } + } + self + } +} + +use fnv::{FnvHashMap, FnvHasher}; + +fn hash_bytes(bytes: &[u8], factor: u64) -> u32 { + let mut hasher = FnvHasher::default(); + bytes.hash(&mut hasher); + // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash. + // To make them unique for the given tokens, we have to add unfortunately another multiplication. 
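+    // (Descriptive note: this mirrors `hash_bytes` in `backtracking_bpe/model.rs`; a zero
+    // factor would collapse every hash to 0, which is why `from_dictionary` asserts a
+    // non-zero factor.)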
+ ((hasher.finish().wrapping_mul(factor)) >> 32) as u32 +} + +// #[cfg(feature = "rand")] +pub fn find_hash_factor_for_dictionary(tokens: impl IntoIterator>) -> u64 { + use std::collections::HashSet; + + use rand::Rng; + + let all_tokens = tokens.into_iter().collect_vec(); + let mut rnd = rand::thread_rng(); + loop { + let factor: u64 = rnd.gen(); + let mut seen = HashSet::new(); + if all_tokens + .iter() + .all(|token| seen.insert(hash_bytes(token, factor))) + { + return factor; + } + } +} + +pub const USE_ORIGINAL_BPE:bool = false; + impl<'de> Deserialize<'de> for ModelWrapper { fn deserialize(deserializer: D) -> std::result::Result where @@ -86,6 +169,7 @@ impl<'de> Deserialize<'de> for ModelWrapper { WordPiece, WordLevel, Unigram, + BacktrackingBpe, } #[derive(Deserialize)] @@ -109,9 +193,12 @@ impl<'de> Deserialize<'de> for ModelWrapper { let helper = ModelHelper::deserialize(deserializer)?; Ok(match helper { ModelHelper::Tagged(model) => match model.variant { - EnumType::BPE => ModelWrapper::BPE( + EnumType::BPE => if USE_ORIGINAL_BPE{ + ModelWrapper::BPE(serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?) + } + else { ModelWrapper::BacktrackingBpe( serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, - ), + )}, EnumType::WordPiece => ModelWrapper::WordPiece( serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, ), @@ -121,11 +208,34 @@ impl<'de> Deserialize<'de> for ModelWrapper { EnumType::Unigram => ModelWrapper::Unigram( serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, ), + EnumType::BacktrackingBpe => ModelWrapper::BacktrackingBpe( + serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, + ), }, ModelHelper::Legacy(value) => { let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?; match untagged { - ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe), + ModelUntagged::BPE(bpe) => if !USE_ORIGINAL_BPE { + let vocabulary = bpe + .get_vocab() + .into_keys() + .into_iter() + .map(|token| token.into_bytes()); + let merges = bpe + .merges + .iter() + .map(|(a, _)| { + (bpe.id_to_token(a.0).unwrap(), bpe.id_to_token(a.1).unwrap()) + }) + .collect(); + let vocab_vec: Vec<_> = vocabulary.collect(); + let rng_hash = find_hash_factor_for_dictionary(vocab_vec.clone()); + let backtracking_bpe = + BacktrackingBpe::from_dictionary(vocab_vec.clone(), Some(merges), Some(rng_hash)); + ModelWrapper::BacktrackingBpe(backtracking_bpe) + } else { + ModelWrapper::BPE(bpe) + } ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe), ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe), ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe), @@ -139,6 +249,7 @@ impl_enum_from!(WordLevel, ModelWrapper, WordLevel); impl_enum_from!(WordPiece, ModelWrapper, WordPiece); impl_enum_from!(BPE, ModelWrapper, BPE); impl_enum_from!(Unigram, ModelWrapper, Unigram); +impl_enum_from!(BacktrackingBpe, ModelWrapper, BacktrackingBpe); impl Model for ModelWrapper { type Trainer = TrainerWrapper; @@ -149,6 +260,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.tokenize(tokens), Self::BPE(t) => t.tokenize(tokens), Self::Unigram(t) => t.tokenize(tokens), + Self::BacktrackingBpe(t) => t.tokenize(tokens), } } @@ -158,6 +270,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.token_to_id(token), Self::BPE(t) => t.token_to_id(token), Self::Unigram(t) => t.token_to_id(token), + Self::BacktrackingBpe(t) => t.token_to_id(token), } } @@ -167,6 +280,7 @@ impl Model for ModelWrapper { 
Self::WordPiece(t) => t.id_to_token(id), Self::BPE(t) => t.id_to_token(id), Self::Unigram(t) => t.id_to_token(id), + Self::BacktrackingBpe(t) => t.id_to_token(id), } } @@ -176,6 +290,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab(), Self::BPE(t) => t.get_vocab(), Self::Unigram(t) => t.get_vocab(), + Self::BacktrackingBpe(t) => t.get_vocab(), } } @@ -185,6 +300,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab_size(), Self::BPE(t) => t.get_vocab_size(), Self::Unigram(t) => t.get_vocab_size(), + Self::BacktrackingBpe(t) => t.get_vocab_size(), } } @@ -194,6 +310,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.save(folder, name), Self::BPE(t) => t.save(folder, name), Self::Unigram(t) => t.save(folder, name), + Self::BacktrackingBpe(t) => t.save(folder, name), } } @@ -203,6 +320,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_trainer().into(), Self::BPE(t) => t.get_trainer().into(), Self::Unigram(t) => t.get_trainer().into(), + Self::BacktrackingBpe(t) => t.get_trainer().into(), } } } @@ -284,7 +402,6 @@ impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer); impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer); impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer); impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer); - #[cfg(test)] mod tests { use super::*; diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 8396f1a7b..a8f53f38a 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -44,8 +44,8 @@ lazy_static! { r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" ) .unwrap(); - static ref BYTES_CHAR: HashMap = bytes_char(); - static ref CHAR_BYTES: HashMap = + pub static ref BYTES_CHAR: HashMap = bytes_char(); + pub static ref CHAR_BYTES: HashMap = bytes_char().into_iter().map(|(c, b)| (b, c)).collect(); } diff --git a/tokenizers/src/tokenizer/serialization.rs b/tokenizers/src/tokenizer/serialization.rs index 26d8344f4..411cdedae 100644 --- a/tokenizers/src/tokenizer/serialization.rs +++ b/tokenizers/src/tokenizer/serialization.rs @@ -135,7 +135,7 @@ where builder = builder.with_pre_tokenizer(map.next_value()?); } "model" => { - builder = builder.with_model(map.next_value()?); + builder = builder.with_model(map.next_value()?); // TODO use with_seed to add args / kwargs } "decoder" => { builder = builder.with_decoder(map.next_value()?);