6
6
import os
7
7
import tempfile
8
8
import uuid
9
+ from typing import Optional
9
10
10
11
import requests
11
12
@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
26
27
return resp .content
27
28
28
29
29
- def read_file_cached (blobpath : str ) -> bytes :
30
def check_hash(data: bytes, hash: str) -> bool:
    """Return True iff the SHA-256 hex digest of ``data`` matches ``hash``.

    Used by the file-loading helpers to validate cached and freshly
    downloaded vocab blobs against a known-good digest.
    """
    # The expected digest is public data, so an ordinary string comparison
    # is sufficient — no constant-time comparison needed here.
    return hashlib.sha256(data).hexdigest() == hash
33
+
34
+
35
+ def read_file_cached (blobpath : str , expected_hash : Optional [str ]= None ) -> bytes :
30
36
user_specified_cache = True
31
37
if "TIKTOKEN_CACHE_DIR" in os .environ :
32
38
cache_dir = os .environ ["TIKTOKEN_CACHE_DIR" ]
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
45
51
cache_path = os .path .join (cache_dir , cache_key )
46
52
if os .path .exists (cache_path ):
47
53
with open (cache_path , "rb" ) as f :
48
- return f .read ()
54
+ data = f .read ()
55
+ if expected_hash and not check_hash (data , expected_hash ):
56
+ raise ValueError (
57
+ f"Hash mismatch for cached data from { blobpath } (expected { expected_hash } ). "
58
+ f"Please delete the cache file at { cache_path } and try again."
59
+ )
60
+ return data
49
61
50
62
contents = read_file (blobpath )
63
+ if expected_hash and not check_hash (contents , expected_hash ):
64
+ raise ValueError (
65
+ f"Hash mismatch for data downloaded from { blobpath } (expected { expected_hash } ). "
66
+ f"This may indicate a corrupted download. Please try again."
67
+ )
51
68
52
69
try :
53
70
os .makedirs (cache_dir , exist_ok = True )
@@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:
64
81
65
82
66
83
def data_gym_to_mergeable_bpe_ranks (
67
- vocab_bpe_file : str , encoder_json_file : str
84
+ vocab_bpe_file : str , encoder_json_file : str , vocab_bpe_hash : Optional [ str ] = None , encoder_json_hash : Optional [ str ] = None
68
85
) -> dict [bytes , int ]:
69
86
# NB: do not add caching to this function
70
87
rank_to_intbyte = [b for b in range (2 ** 8 ) if chr (b ).isprintable () and chr (b ) != " " ]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
79
96
assert len (rank_to_intbyte ) == 2 ** 8
80
97
81
98
# vocab_bpe contains the merges along with associated ranks
82
- vocab_bpe_contents = read_file_cached (vocab_bpe_file ).decode ()
99
+ vocab_bpe_contents = read_file_cached (vocab_bpe_file , vocab_bpe_hash ).decode ()
83
100
bpe_merges = [tuple (merge_str .split ()) for merge_str in vocab_bpe_contents .split ("\n " )[1 :- 1 ]]
84
101
85
102
def decode_data_gym (value : str ) -> bytes :
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
96
113
# check that the encoder file matches the merges file
97
114
# this sanity check is important since tiktoken assumes that ranks are ordered the same
98
115
# as merge priority
99
- encoder_json = json .loads (read_file_cached (encoder_json_file ))
116
+ encoder_json = json .loads (read_file_cached (encoder_json_file , encoder_json_hash ))
100
117
encoder_json_loaded = {decode_data_gym (k ): v for k , v in encoder_json .items ()}
101
118
# drop these two special tokens if present, since they're not mergeable bpe tokens
102
119
encoder_json_loaded .pop (b"<|endoftext|>" , None )
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
118
135
f .write (base64 .b64encode (token ) + b" " + str (rank ).encode () + b"\n " )
119
136
120
137
121
def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str] = None) -> dict[bytes, int]:
    """Load a tiktoken BPE rank file into a ``{token_bytes: rank}`` mapping.

    Each non-empty line of the (possibly cached) file is expected to hold a
    base64-encoded token followed by its integer rank, separated by
    whitespace. ``expected_hash``, when given, is forwarded to
    ``read_file_cached`` so the file contents are verified before parsing.
    """
    # NB: do not add caching to this function
    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
    ranks: dict[bytes, int] = {}
    for line in contents.splitlines():
        if not line:
            continue
        token, rank = line.split()
        ranks[base64.b64decode(token)] = int(rank)
    return ranks
0 commit comments