diff --git a/prepare.py b/prepare.py
index 06bea9165..1513d1beb 100644
--- a/prepare.py
+++ b/prepare.py
@@ -55,11 +55,18 @@
 # ---------------------------------------------------------------------------
 def download_single_shard(index):
     """Download one parquet shard with retries. Returns True on success."""
     filename = f"shard_{index:05d}.parquet"
     filepath = os.path.join(DATA_DIR, filename)
+
+    # 1. Integrity check: validate Parquet metadata if the file already exists locally
     if os.path.exists(filepath):
-        return True
+        try:
+            pq.read_metadata(filepath)  # Semantic validation: ensure Parquet structural integrity
+            return True
+        except Exception:
+            print(f"  Detected corrupted file (unknown), re-downloading...")
+            os.remove(filepath)
 
     url = f"{BASE_URL}/(unknown)"
     max_attempts = 5
@@ -67,24 +74,39 @@ def download_single_shard(index):
         try:
             response = requests.get(url, stream=True, timeout=30)
             response.raise_for_status()
+
+            # Retrieve expected file size from the HTTP Content-Length header
+            expected_size = int(response.headers.get('content-length', 0))
+
             temp_path = filepath + ".tmp"
             with open(temp_path, "wb") as f:
                 for chunk in response.iter_content(chunk_size=1024 * 1024):
                     if chunk:
                         f.write(chunk)
+
+            # 2. Physical validation: verify byte count consistency
+            actual_size = os.path.getsize(temp_path)
+            if expected_size > 0 and actual_size < expected_size:
+                raise IOError(f"Size mismatch: expected {expected_size}, got {actual_size}")
+
+            # 3. Semantic validation: confirm the file adheres to the Parquet specification
+            try:
+                pq.read_metadata(temp_path)
+            except Exception as e:
+                raise IOError(f"Invalid Parquet metadata: {e}")
+
             os.rename(temp_path, filepath)
-            print(f"  Downloaded (unknown)")
+            print(f"  Downloaded and verified (unknown)")
             return True
+
         except (requests.RequestException, IOError) as e:
-            print(f"  Attempt {attempt}/{max_attempts} failed for (unknown): {e}")
-            for path in [filepath + ".tmp", filepath]:
-                if os.path.exists(path):
-                    try:
-                        os.remove(path)
-                    except OSError:
-                        pass
-            if attempt < max_attempts:
-                time.sleep(2 ** attempt)
+            print(f"  Attempt {attempt}/{max_attempts} failed: {e}")
+            # NOTE: `temp_path` may be unbound if requests.get() itself raised,
+            # so rebuild the .tmp path instead of referencing the local.
+            if os.path.exists(filepath + ".tmp"):
+                os.remove(filepath + ".tmp")
+            # Back off only while retries remain; never sleep after the final failure.
+            if attempt < max_attempts:
+                time.sleep(2 ** attempt)
+
     return False
@@ -136,6 +158,7 @@ def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
             yield doc
             if nchars >= max_chars:
                 return
+
 def train_tokenizer():