
Draft backtrack #1712


Draft
wants to merge 34 commits into base: main

Changes from all commits (34 commits)
ac46243
initial-commit
ArthurZucker Jan 3, 2025
3e884ac
update test
ArthurZucker Jan 3, 2025
0907a4d
update benches as well
ArthurZucker Jan 3, 2025
dca4568
nits
ArthurZucker Jan 3, 2025
d334fb4
add no pretokenizer bench
ArthurZucker Jan 3, 2025
a3ed0c3
push latest changes
ArthurZucker Jan 3, 2025
7c9e534
updates
ArthurZucker Jan 4, 2025
ee18ba9
nits
ArthurZucker Jan 4, 2025
4b63a7a
update serialization to support initializing from BPE
ArthurZucker Jan 6, 2025
a7baf1b
nits
ArthurZucker Jan 6, 2025
224f432
current state
ArthurZucker Jan 21, 2025
7f4dc95
updates
ArthurZucker Feb 10, 2025
7d28132
Merge branch 'main' of github.com:huggingface/tokenizers into backtrack
ArthurZucker Feb 10, 2025
bc0b3fb
fix merge issue
ArthurZucker Feb 10, 2025
f27365d
update deserialization!
ArthurZucker Feb 10, 2025
a387594
nits
ArthurZucker Feb 11, 2025
9707851
remove on of the trainers
ArthurZucker Feb 12, 2025
1073896
add a trait for training, but actually just manually converting when …
ArthurZucker Feb 12, 2025
28ee5ed
nit
ArthurZucker Feb 12, 2025
fa15b75
Merge branch 'backtrack' of github.com:huggingface/tokenizers into ba…
ArthurZucker Feb 12, 2025
7b3d09e
clean
ArthurZucker Feb 12, 2025
028f7c3
build
ArthurZucker Feb 12, 2025
31b58b0
update
ArthurZucker Feb 16, 2025
0082bc0
fmt
ArthurZucker Feb 16, 2025
18b9f5a
at least build, but this is still shit
ArthurZucker Feb 16, 2025
b9087a7
happy now?
ArthurZucker Feb 16, 2025
be1129f
cleanup
ArthurZucker Feb 16, 2025
f6054f6
fix some of the bugs: forgot about the byte level encoding !
ArthurZucker Feb 16, 2025
9b9ea69
fix tests that were not building!
ArthurZucker Feb 16, 2025
2ca9329
update test
ArthurZucker Feb 16, 2025
8e59945
current state
ArthurZucker Feb 16, 2025
b080b34
removed prints for clarity
ArthurZucker Feb 16, 2025
fe29aca
small fixes
ArthurZucker Feb 16, 2025
47070c1
fix the bench
ArthurZucker Feb 16, 2025
9 changes: 5 additions & 4 deletions bindings/python/benches/test_tiktoken.py
@@ -63,9 +63,11 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
out = enc.encode("This is a test")

hf_enc = Tokenizer.from_pretrained(model)
hf_enc.pre_tokenizer = None # apparently backtracking does not work without this
out2 = hf_enc.encode("This is a test", add_special_tokens=False).ids

assert out == out2, "sanity check"
print([hf_enc.decode([k]) for k in out2])
print([hf_enc.decode([k]) for k in out])
assert out == out2, f"sanity check {out} == {out2}, {hf_enc.decode(out)} == {hf_enc.decode(out2)}"

start = time.perf_counter_ns()
enc.encode_ordinary_batch(documents, num_threads=num_threads)
@@ -74,9 +76,8 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9)
print(f"tiktoken \t{readable_size} / s")


start = time.perf_counter_ns()
hf_enc.encode_batch_fast(documents)
hf_enc.encode_batch(documents)
end = time.perf_counter_ns()
readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9)
print(f"huggingface \t{readable_size} / s")
146 changes: 146 additions & 0 deletions bindings/python/src/models.rs
@@ -8,6 +8,7 @@ use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::{Deserialize, Serialize};
use tk::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeBuilder};
use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
use tk::models::unigram::Unigram;
use tk::models::wordlevel::WordLevel;
@@ -39,6 +40,10 @@ impl PyModel {
.into_pyobject(py)?
.into_any()
.into(),
ModelWrapper::BacktrackingBpe(_) => Py::new(py, (PyBacktrackingBpe {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?
.into_pyobject(py)?
.into_any()
@@ -572,6 +577,147 @@ impl PyBPE {
}
}

#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BacktrackingBpe")]
struct PyBacktrackingBpe {}

impl PyBacktrackingBpe {
fn with_builder(
mut builder: BacktrackingBpeBuilder,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Self, PyModel)> {
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: String = key.extract()?;
match key.as_ref() {
"unk_token" => {
if let Some(unk) = value.extract()? {
builder = builder.unk_token(unk);
}
}
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
};
}
}

match builder.build() {
Err(e) => Err(exceptions::PyException::new_err(format!(
"Error while initializing BPE: {}",
e
))),
Ok(bpe) => Ok((PyBacktrackingBpe {}, bpe.into())),
}
}
}

#[pymethods]
impl PyBacktrackingBpe {
#[new]
#[pyo3(
signature = (vocab=None, merges=None, **kwargs),
text_signature = "(self, vocab=None, merges=None, dropout=None, unk_token=None)")]
fn new(
py: Python<'_>,
vocab: Option<PyVocab>,
merges: Option<PyMerges>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Self, PyModel)> {
if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
return Err(exceptions::PyValueError::new_err(
"`vocab` and `merges` must be both specified",
));
}

let mut builder = BacktrackingBpe::builder();
if let (Some(vocab), Some(merges)) = (vocab, merges) {
match (vocab, merges) {
(PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => {
builder = builder.vocab_and_merges(vocab, merges);
}
_ => {
return Err(exceptions::PyValueError::new_err(
"`vocab` and `merges` must be both be from memory or both filenames",
));
}
}
}

PyBacktrackingBpe::with_builder(builder, kwargs)
}

/// Read a :obj:`vocab.json` and a :obj:`merges.txt` file
///
/// This method provides a way to read and parse the content of these files,
/// returning the relevant data structures. If you want to instantiate some BPE models
/// from memory, this method gives you the expected input from the standard files.
///
/// Args:
/// vocab (:obj:`str`):
/// The path to a :obj:`vocab.json` file
///
/// merges (:obj:`str`):
/// The path to a :obj:`merges.txt` file
///
/// Returns:
/// A :obj:`Tuple` with the vocab and the merges:
/// The vocabulary and merges loaded into memory
#[staticmethod]
#[pyo3(text_signature = "(self, vocab, merges)")]
fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> {
BacktrackingBpe::read_file(vocab, merges).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while reading vocab & merges files: {}",
e
))
})
}

/// Instantiate a BPE model from the given files.
///
/// This method is roughly equivalent to doing::
///
/// vocab, merges = BPE.read_file(vocab_filename, merges_filename)
/// bpe = BPE(vocab, merges)
///
/// If you don't need to keep the :obj:`vocab, merges` values lying around,
/// this method is more optimized than manually calling
/// :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE`
///
/// Args:
/// vocab (:obj:`str`):
/// The path to a :obj:`vocab.json` file
///
/// merges (:obj:`str`):
/// The path to a :obj:`merges.txt` file
///
/// Returns:
/// :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files
#[classmethod]
#[pyo3(signature = (vocab, merges, **kwargs))]
#[pyo3(text_signature = "(cls, vocab, merges, **kwargs)")]
fn from_file(
_cls: &Bound<'_, PyType>,
py: Python,
vocab: &str,
merges: &str,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<Self>> {
let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e))
})?;
Py::new(
py,
PyBacktrackingBpe::new(
py,
Some(PyVocab::Vocab(vocab)),
Some(PyMerges::Merges(merges)),
kwargs,
)?,
)
}
}

/// An implementation of the WordPiece algorithm
///
/// Args:
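
For orientation, here is a minimal sketch (not part of this PR) of how the new model could be constructed from the Rust side, using only the builder calls the binding above forwards to (`BacktrackingBpe::builder()`, `vocab_and_merges`, `byte_fallback`, `build`); the tiny vocab and merges are made-up values.

use tokenizers::models::backtracking_bpe::BacktrackingBpe;

fn main() {
    // Made-up toy vocab/merges, shaped like the inputs of the existing BPE model.
    let vocab = [
        ("a".to_string(), 0u32),
        ("b".to_string(), 1),
        ("ab".to_string(), 2),
    ]
    .into_iter()
    .collect();
    let merges = vec![("a".to_string(), "b".to_string())];

    let model = BacktrackingBpe::builder()
        .vocab_and_merges(vocab, merges)
        .byte_fallback(false)
        .build()
        .expect("building the backtracking BPE model");

    // The model can then be wrapped in a `Tokenizer` like any other `ModelWrapper` variant.
    let _ = model;
}
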
2 changes: 2 additions & 0 deletions tokenizers/Cargo.toml
@@ -68,6 +68,8 @@ fancy-regex = { version = "0.14", optional = true}
getrandom = { version = "0.2.10" }
esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
monostate = "0.1.12"
fnv = "1.0.7"
aneubeck-daachorse = "1.1.1"

[features]
default = ["progressbar", "onig", "esaxx_fast"]
44 changes: 35 additions & 9 deletions tokenizers/benches/llama3.rs
@@ -2,26 +2,51 @@
extern crate criterion;

use criterion::{Criterion, Throughput};
use itertools::Itertools;
use tokenizers::models::backtracking_bpe;
use tokenizers::PreTokenizerWrapper;
use tokenizers::Tokenizer;

pub fn llama3(c: &mut Criterion) {
let data = std::fs::read_to_string("data/big.txt").unwrap();
let mut group = c.benchmark_group("llama3-encode");
group.throughput(Throughput::Bytes(data.bytes().len() as u64));
group.bench_function("llama3-offsets", |b| {
let tokenizer =
Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap();

group.bench_function("llama3-backtracking", |b| {
let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
tokenizer
.encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
.unwrap()
})
});

group.bench_function("llama3-backtracking-no-pretok", |b| {
let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
tokenizer.with_pre_tokenizer(None::<PreTokenizerWrapper>);
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
tokenizer
.encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens)
.encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
.unwrap()
})
});
group.bench_function("llama3-nooffsets", |b| {
let tokenizer =
Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap();

group.bench_function("llama3-encode_batch_fast", |b| {
let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
tokenizer
.encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
.unwrap()
})
});
group.bench_function("llama3-encode_batch", |b| {
let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
@@ -30,13 +55,14 @@ pub fn llama3(c: &mut Criterion) {
.unwrap()
})
});

group.finish();
}

criterion_group! {
name = bert_benches;
name = llama;
config = Criterion::default().sample_size(10);
targets = llama3
}

criterion_main!(bert_benches);
criterion_main!(llama);
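
The four bench closures above share the same shape and differ only in how the tokenizer is set up and which encode method is called; a small helper along these lines (a sketch, not part of the PR) would cut the repetition:

use criterion::measurement::WallTime;
use criterion::BenchmarkGroup;
use tokenizers::Tokenizer;

// Sketch only: registers one benchmark that encodes the same borrowed lines
// with `encode_batch_fast`, mirroring the closures in the diff above.
fn bench_encode_fast(
    group: &mut BenchmarkGroup<'_, WallTime>,
    name: &str,
    tokenizer: &Tokenizer,
    data: &[&str],
) {
    group.bench_function(name, |b| {
        b.iter(|| {
            tokenizer
                .encode_batch_fast(criterion::black_box(data.to_vec()), false)
                .unwrap()
        })
    });
}
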
49 changes: 49 additions & 0 deletions tokenizers/src/models/backtracking_bpe/backtracking_state.rs
@@ -0,0 +1,49 @@
use super::bitfield::BitField;

/// This can be thought of as a lazy variation of the dynamic programming approach.
/// It only computes those states which have to be visited in order to compute the tokenization
/// for a given input text.
/// It keeps track of visited states in a bitfield and only remembers the tokenization
/// of the currently processed dynamic programming state.
///
/// The biggest downside of this approach is that the search for the longest leftmost match (the first token)
/// has to be restarted at every backtracking step; in practice this is still a net win compared to other approaches.
#[derive(Clone, PartialEq)]
pub struct BacktrackState<'a> {
pub(crate) text: &'a [u8],
pub(crate) tokens: Vec<u32>, // roughly text.len() / 3 tokens are expected
pub(crate) next_token: Option<u32>, // bpe.next_match(text), i.e. the first match of longest_searcher.leftmost_find_iter(text)
pub(crate) pos: usize, // current position in the text
pub(crate) bitfield: BitField, // tracks the positions that are still valid token boundaries, keeping the runtime linear in the input length
}

impl<'a> BacktrackState<'a> {
pub(crate) fn new(text: &'a [u8], next_token: Option<u32>) -> Self {
Self::with_capacity(text, next_token, text.len() / 3)
}

pub(crate) fn with_capacity(text: &'a [u8], next_token: Option<u32>, cap: usize) -> Self {
Self {
text,
tokens: Vec::with_capacity(cap),
next_token,
pos: 0,
bitfield: BitField::new(text.len() + 1),
}
}
pub(crate) fn count(&self) -> usize {
self.tokens.len()
}

pub(crate) fn pos(&self) -> usize {
self.pos
}

pub(crate) fn last_token(&self) -> Option<u32> {
self.tokens.last().copied()
}

pub(crate) fn into_tokens(self) -> Vec<u32> {
self.tokens
}
}
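
As a quick illustration of the invariants above (not part of the PR, and only meaningful inside this module since the struct is crate-private), a would-be unit test for a freshly constructed state could look like this:

#[cfg(test)]
mod tests {
    use super::BacktrackState;

    #[test]
    fn fresh_state_is_empty() {
        let text = b"lowest";
        // 42 is a hypothetical id for the first longest leftmost match in `text`.
        let state = BacktrackState::new(text, Some(42));
        assert_eq!(state.count(), 0); // no tokens emitted yet
        assert_eq!(state.pos(), 0); // cursor still at the start of the text
        assert_eq!(state.last_token(), None); // nothing to backtrack over
        assert!(state.bitfield.is_set(text.len())); // every boundary starts out as valid
    }
}
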
57 changes: 57 additions & 0 deletions tokenizers/src/models/backtracking_bpe/bitfield.rs
@@ -0,0 +1,57 @@
/// Small helper to manage a bit field which supports predecessor and successor queries with a simple scan implementation.
/// This is sufficient for our use case, since two one bits will be at most 128 bits apart.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct BitField {
bitfield: Vec<u64>,
}

impl BitField {
/// All bits are initialized to 1.
pub(crate) fn new(bits: usize) -> Self {
Self {
bitfield: vec![u64::MAX; (bits + 63) / 64],
}
}

pub(crate) fn is_set(&self, bit: usize) -> bool {
let (word, bit) = (bit / 64, bit % 64);
self.bitfield[word] & (1 << bit) != 0
}

pub(crate) fn clear(&mut self, bit: usize) {
let (word, bit) = (bit / 64, bit % 64);
self.bitfield[word] &= !(1 << bit);
}

pub(crate) fn successor(&self, bit: usize) -> usize {
let (mut word_idx, bit_idx) = (bit / 64, bit % 64);
let word = self.bitfield[word_idx] >> bit_idx;
if word != 0 {
word.trailing_zeros() as usize + bit
} else {
loop {
word_idx += 1;
let word = self.bitfield[word_idx];
if word != 0 {
break word.trailing_zeros() as usize + word_idx * 64;
}
}
}
}

pub(crate) fn predecessor(&self, bit: usize) -> usize {
let (mut word_idx, bit_idx) = (bit / 64, bit % 64);
let word = self.bitfield[word_idx] << (63 - bit_idx);
if word != 0 {
bit - word.leading_zeros() as usize
} else {
loop {
word_idx -= 1;
let word = self.bitfield[word_idx];
if word != 0 {
break word_idx * 64 + 63 - word.leading_zeros() as usize;
}
}
}
}
}
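
A would-be unit test (again a sketch, not in the PR) makes the successor/predecessor semantics concrete, including the scan across a word boundary:

#[cfg(test)]
mod tests {
    use super::BitField;

    #[test]
    fn scan_queries_cross_word_boundaries() {
        let mut bf = BitField::new(130); // three 64-bit words, all bits set
        bf.clear(64);
        bf.clear(65);
        assert!(!bf.is_set(64));
        assert_eq!(bf.successor(64), 66); // next set bit at or after 64
        assert_eq!(bf.predecessor(65), 63); // nearest set bit at or before 65, found in the previous word
    }
}
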
6 changes: 6 additions & 0 deletions tokenizers/src/models/backtracking_bpe/mod.rs
@@ -0,0 +1,6 @@
mod backtracking_state;
mod bitfield;
mod model;
mod serialization;

pub use model::*;