Skip to content

Commit

Permalink
Add possessive quantifiers to avoid catastrophic backtracking (#258)
Browse files Browse the repository at this point in the history
Fixes the crash in #245 by
prohibiting the regex engine from backtracking catastrophically via
[possessive
quantifiers](https://www.regular-expressions.info/possessive.html).

<img width="400" alt="image"
src="https://github.com/openai/tiktoken/assets/1841944/ed341153-4cf4-4c1c-93d6-3f5e32133569">

Interestingly, these possessive quantifiers make the encoding a lot faster again in
`fancy-regex`.

Before this change (but with the large byte-pair merge PR cherry-picked):
```
num_threads: 1, num_bytes: 98379553
tiktoken 	11,946,036 bytes / s
tiktoken 	11,961,343 bytes / s
tiktoken 	11,995,846 bytes / s
tiktoken 	11,951,263 bytes / s
tiktoken 	11,983,405 bytes / s
```
Same, with these changes applied:
```
num_threads: 1, num_bytes: 98379553
tiktoken 	14,511,827 bytes / s
tiktoken 	14,638,134 bytes / s
tiktoken 	14,644,029 bytes / s
tiktoken 	14,729,030 bytes / s
tiktoken 	14,666,903 bytes / s
```
Updating the regex libs makes it a tiny bit faster still:
```
num_threads: 1, num_bytes: 98379553
tiktoken 	14,485,590 bytes / s
tiktoken 	14,854,049 bytes / s
tiktoken 	14,891,086 bytes / s
tiktoken 	14,843,007 bytes / s
tiktoken 	14,874,520 bytes / s
```

This is almost 2x faster than [before any of the
optimizations](#234).

-------

Opened an issue for increasing the [default backtrack
limit](https://github.com/fancy-regex/fancy-regex/blob/bf2c807447f72ee20ae839e0f8cb3a06fc79982c/src/lib.rs#L407),
see: fancy-regex/fancy-regex#134, but it
shouldn't be necessary here anymore.

---------

Co-authored-by: Lőrinc <[email protected]>
  • Loading branch information
l0rinc and Lőrinc authored Oct 3, 2024
1 parent c0ba74c commit 9f7f69d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.20.0", features = ["extension-module"] }

# tiktoken dependencies
fancy-regex = "0.11.0"
regex = "1.8.3"
fancy-regex = "0.13.0"
regex = "1.10.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"
16 changes: 15 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::num::NonZeroU64;
use std::thread;

use fancy_regex::Regex;
use fancy_regex::RegexBuilder;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::pyclass;
Expand Down Expand Up @@ -417,7 +418,7 @@ impl CoreBPE {
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> PyResult<Self> {
let regex = Regex::new(pattern)
let regex = RegexBuilder::new(pattern).backtrack_limit(10_000).build()
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

let special_regex = {
Expand Down Expand Up @@ -572,6 +573,7 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {

#[cfg(test)]
mod tests {
use fancy_regex::RegexBuilder;
use rustc_hash::FxHashMap as HashMap;

use crate::{byte_pair_split, Rank};
Expand All @@ -596,4 +598,16 @@ mod tests {
let res = byte_pair_split(b"abab", &ranks);
assert_eq!(res, vec![b"ab", b"ab"]);
}

/// Checks that a very low backtrack limit turns runaway backtracking into an error.
///
/// The pattern repeats overlapping alternatives (`a`, `b`, `ab`) under `*` and ends
/// with a look-ahead; the builder is given a backtrack limit of only 10. Matching
/// 100 repetitions of "ab" followed by "c" is therefore expected to yield `Err`
/// from `is_match` instead of a boolean result.
#[test]
fn test_effect_of_backtrack_limit() {
    // `build()` already returns an owned `Regex`, so cloning it afterwards is redundant.
    let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)")
        .backtrack_limit(10)
        .build()
        .expect("Failed to build regex");

    let input = "ab".repeat(100) + "c";
    assert!(regex.is_match(&input).is_err(), "Should throw");
}
}
16 changes: 16 additions & 0 deletions tests/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@
from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
    """Round-trip very long, highly repetitive strings through an encoding.

    For each seed fragment three inputs are checked in order: the fragment
    repeated 10,000 times, that string with a leading space, and that string
    with a trailing newline added as well. Every input must survive
    ``encode`` followed by ``decode`` unchanged.
    """
    enc = make_enc()
    for c in ["^", "0", "a", "'s", " ", "\n"]:
        print(f"Validating `{c}`")

        repeated = c * 10_000
        with_leading_space = " " + repeated
        with_trailing_newline = with_leading_space + "\n"

        # Each candidate extends the previous one, so the checks stay cumulative.
        for big_value in (repeated, with_leading_space, with_trailing_newline):
            assert big_value == enc.decode(enc.encode(big_value))


def test_simple():
enc = tiktoken.get_encoding("gpt2")
assert enc.encode("hello world") == [31373, 995]
Expand Down
18 changes: 10 additions & 8 deletions tiktoken_ext/openai_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"

# The pattern in the original GPT-2 release is:
# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# This is equivalent, but executes faster:
# - the contraction alternatives are folded into one non-capturing group, and
# - the repeated runs use possessive quantifiers (`++`), which keep the regex
#   engine from backtracking into text they have already consumed.
# Shared as `pat_str` by the gpt2, r50k_base, p50k_base and p50k_edit encodings.
_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""


def gpt2():
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
Expand All @@ -17,10 +22,7 @@ def gpt2():
return {
"name": "gpt2",
"explicit_n_vocab": 50257,
# The pattern in the original GPT-2 release is:
# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# This is equivalent, but executes faster:
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
Expand All @@ -34,7 +36,7 @@ def r50k_base():
return {
"name": "r50k_base",
"explicit_n_vocab": 50257,
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
Expand All @@ -48,7 +50,7 @@ def p50k_base():
return {
"name": "p50k_base",
"explicit_n_vocab": 50281,
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
Expand All @@ -62,7 +64,7 @@ def p50k_edit():
special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
return {
"name": "p50k_edit",
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
Expand All @@ -82,7 +84,7 @@ def cl100k_base():
}
return {
"name": "cl100k_base",
"pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
"pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
Expand Down

0 comments on commit 9f7f69d

Please sign in to comment.