Commit cb31cd2
Sync codebase
1 parent e35ab09 commit cb31cd2
4 files changed: +30 -8 lines

.github/workflows/build_wheels.yml (+2 -2)

@@ -38,7 +38,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-22.04-arm]
+        os: [ubuntu-24.04-arm]
         python-version: [39, 310, 311, 312, 313]

     steps:
@@ -55,7 +55,7 @@ jobs:

       - uses: actions/upload-artifact@v4
         with:
-          name: cibw-wheelsaarch64-${{ matrix.os }}-${{ strategy.job-index }}
+          name: cibw-wheels-aarch64-${{ matrix.os }}-${{ strategy.job-index }}
           path: ./wheelhouse/*.whl

   build_sdist:

src/lib.rs (+2 -2)

@@ -256,7 +256,7 @@ impl CoreBPE {
             }
             let end = next_special.map_or(text.len(), |m| m.start());

-            // Okay, here we go, compare this logic to _encode_ordinary_native
+            // Okay, here we go, compare this logic to encode_ordinary
             for mat in regex.find_iter(&text[start..end]) {
                 let piece = mat.unwrap().as_str().as_bytes();
                 if let Some(token) = self.encoder.get(piece) {
@@ -398,7 +398,7 @@ impl CoreBPE {
                 // notice all the big holes in the previous unstable token implementation)
                 Err(_) => byte_pair_encode(&possibility, &self.encoder),
                 // Something like the following is intriguing but incorrect:
-                // Err(e) => self._encode_ordinary_native(unsafe {
+                // Err(e) => self.encode_ordinary(unsafe {
                 //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                 // }),
             };

src/py.rs (+15 -4)

@@ -68,23 +68,34 @@ impl CoreBPE {
     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
         py.allow_threads(|| {
             match std::str::from_utf8(bytes) {
+                // Straightforward case
                 Ok(text) => self.encode_ordinary(text),
+                // Oops, don't actually have UTF-8. But we need to do the regex splitting in
+                // Unicode space, so we make our best guess at where we would have splits
                 Err(e) => {
                     let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                     let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
                     let (mut tokens, last_piece_token_len) =
                         self._increase_last_piece_token_len(tokens, last_piece_token_len);
+
+                    let mut unstable_bytes;
                     if !tokens.is_empty() && last_piece_token_len > 0 {
                         // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes = self
+                        // This likely matches what models see better, e.g. if you assume we're
+                        // dealing with truncated UTF-8 bytes.
+                        // Niche, but note this may not be correct if we'd have had a regex
+                        // split between the valid UTF-8 and the invalid bytes.
+                        unstable_bytes = self
                             .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
                             .unwrap();
                         unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                         tokens.truncate(tokens.len() - last_piece_token_len);
+                    } else {
+                        unstable_bytes = bytes[e.valid_up_to()..].to_vec();
+                    }
+
+                    if !unstable_bytes.is_empty() {
                         match self.encoder.get(&unstable_bytes) {
                             Some(token) => tokens.push(*token),
                             None => {
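
The new else branch in src/py.rs means that byte strings whose valid UTF-8 prefix
produces no tokens (for example, input made up entirely of invalid bytes) are no
longer dropped: the leftover bytes are routed through the unstable-bytes path and
byte-pair encoded. A minimal usage sketch of the observable behaviour from Python,
assuming tiktoken is installed (_encode_bytes is a private helper, shown here for
illustration only):

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")

    # Not valid UTF-8 at any prefix; previously the token list could come back
    # empty here, now the bytes are byte-pair encoded and round-trip exactly.
    raw = b"\x80" * 3
    tokens = enc._encode_bytes(raw)
    assert enc.decode_bytes(tokens) == raw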

tests/test_encoding.py (+11)

@@ -78,6 +78,17 @@ def test_encode_empty():
 def test_encode_bytes():
     enc = tiktoken.get_encoding("cl100k_base")
     assert enc._encode_bytes(b" \xec\x8b\xa4\xed") == [62085]
+    for i in range(10):
+        bytestring = b"\x80" * i
+        assert enc.decode_bytes(enc._encode_bytes(bytestring)) == bytestring
+
+
+@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
+@hypothesis.given(bytestring=st.binary())
+@hypothesis.settings(deadline=None)
+def test_hyp_encode_bytes(make_enc: Callable[[], tiktoken.Encoding], bytestring: bytes):
+    enc = make_enc()
+    assert enc.decode_bytes(enc._encode_bytes(bytestring)) == bytestring


 def test_encode_surrogate_pairs():
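
The new hypothesis test checks the round-trip property over arbitrary byte strings.
A self-contained sketch of the same property, assuming tiktoken and hypothesis are
installed and using cl100k_base directly instead of the test suite's
ENCODING_FACTORIES fixture:

    import hypothesis
    import hypothesis.strategies as st
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")

    @hypothesis.given(bytestring=st.binary())
    @hypothesis.settings(deadline=None)
    def test_encode_bytes_roundtrip(bytestring: bytes) -> None:
        # decode_bytes should invert _encode_bytes even when the input is not
        # valid UTF-8 and so cannot go through the normal text path.
        assert enc.decode_bytes(enc._encode_bytes(bytestring)) == bytestring

    # Hypothesis-decorated functions can be called directly; the decorator runs
    # the body against many generated byte strings.
    test_encode_bytes_roundtrip()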
