Skip to content

Commit 24d68bd

Browse files
authored
Merge branch 'main' into paplorinc/add-linearithmic-byte-pair-merge
2 parents d24b67b + db5bda9 commit 24d68bd

File tree

3 files changed

+35
-13
lines changed

3 files changed

+35
-13
lines changed

README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ If you work at OpenAI, make sure to check the internal documentation or feel fre
4242

4343
## What is BPE anyway?
4444

45-
Models don't see text like you and I, instead they see a sequence of numbers (known as tokens).
45+
Language models don't see text like you and I, instead they see a sequence of numbers (known as tokens).
4646
Byte pair encoding (BPE) is a way of converting text into tokens. It has a couple desirable
4747
properties:
4848
1) It's reversible and lossless, so you can convert tokens back into the original text
@@ -128,4 +128,3 @@ setup(
128128

129129
Then simply `pip install ./my_tiktoken_extension` and you should be able to use your
130130
custom encodings! Make sure **not** to use an editable install.
131-

tiktoken/load.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import tempfile
88
import uuid
9+
from typing import Optional
910

1011
import requests
1112

@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
2627
return resp.content
2728

2829

29-
def read_file_cached(blobpath: str) -> bytes:
30+
def check_hash(data: bytes, hash: str) -> bool:
    """Return True iff the SHA-256 hex digest of *data* matches *hash*.

    Used to verify the integrity of downloaded/cached BPE vocabulary files
    against a known-good digest.
    """
    # NOTE: parameter name `hash` shadows the builtin, kept for interface
    # compatibility with existing callers.
    return hashlib.sha256(data).hexdigest() == hash
33+
34+
35+
def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
3036
user_specified_cache = True
3137
if "TIKTOKEN_CACHE_DIR" in os.environ:
3238
cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
4551
cache_path = os.path.join(cache_dir, cache_key)
4652
if os.path.exists(cache_path):
4753
with open(cache_path, "rb") as f:
48-
return f.read()
54+
data = f.read()
55+
if expected_hash and not check_hash(data, expected_hash):
56+
raise ValueError(
57+
f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
58+
f"Please delete the cache file at {cache_path} and try again."
59+
)
60+
return data
4961

5062
contents = read_file(blobpath)
63+
if expected_hash and not check_hash(contents, expected_hash):
64+
raise ValueError(
65+
f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
66+
f"This may indicate a corrupted download. Please try again."
67+
)
5168

5269
try:
5370
os.makedirs(cache_dir, exist_ok=True)
@@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:
6481

6582

6683
def data_gym_to_mergeable_bpe_ranks(
67-
vocab_bpe_file: str, encoder_json_file: str
84+
vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str]=None, encoder_json_hash: Optional[str]=None
6885
) -> dict[bytes, int]:
6986
# NB: do not add caching to this function
7087
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
7996
assert len(rank_to_intbyte) == 2**8
8097

8198
# vocab_bpe contains the merges along with associated ranks
82-
vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
99+
vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
83100
bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
84101

85102
def decode_data_gym(value: str) -> bytes:
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
96113
# check that the encoder file matches the merges file
97114
# this sanity check is important since tiktoken assumes that ranks are ordered the same
98115
# as merge priority
99-
encoder_json = json.loads(read_file_cached(encoder_json_file))
116+
encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
100117
encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
101118
# drop these two special tokens if present, since they're not mergeable bpe tokens
102119
encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
118135
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
119136

120137

121-
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
138+
def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str]=None) -> dict[bytes, int]:
122139
# NB: do not add caching to this function
123-
contents = read_file_cached(tiktoken_bpe_file)
140+
contents = read_file_cached(tiktoken_bpe_file, expected_hash)
124141
return {
125142
base64.b64decode(token): int(rank)
126143
for token, rank in (line.split() for line in contents.splitlines() if line)

tiktoken_ext/openai_public.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def gpt2():
1111
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
1212
vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
1313
encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
14+
vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
15+
encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
1416
)
1517
return {
1618
"name": "gpt2",
@@ -23,7 +25,8 @@ def gpt2():
2325

2426
def r50k_base():
2527
mergeable_ranks = load_tiktoken_bpe(
26-
"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
28+
"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
29+
expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
2730
)
2831
return {
2932
"name": "r50k_base",
@@ -36,7 +39,8 @@ def r50k_base():
3639

3740
def p50k_base():
3841
mergeable_ranks = load_tiktoken_bpe(
39-
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
42+
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
43+
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
4044
)
4145
return {
4246
"name": "p50k_base",
@@ -49,7 +53,8 @@ def p50k_base():
4953

5054
def p50k_edit():
5155
mergeable_ranks = load_tiktoken_bpe(
52-
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
56+
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
57+
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
5358
)
5459
special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
5560
return {
@@ -62,7 +67,8 @@ def p50k_edit():
6267

6368
def cl100k_base():
6469
mergeable_ranks = load_tiktoken_bpe(
65-
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
70+
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
71+
expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
6672
)
6773
special_tokens = {
6874
ENDOFTEXT: 100257,

0 commit comments

Comments
 (0)