 )
 from transformers.utils.hub import _CACHED_NO_EXIST, TRANSFORMERS_CACHE, extract_commit_hash, try_to_load_from_cache

-from open_instruct.utils import hf_whoami
+from open_instruct.utils import hf_whoami, max_num_processes


 # ----------------------------------------------------------------------------
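The only new import is `max_num_processes`; every `load_dataset` call touched below now passes `num_proc=max_num_processes()` so that dataset download and preparation can fan out across worker processes. The helper's implementation is not part of this diff; a minimal sketch of what such a function might look like (an illustrative assumption, not the actual `open_instruct.utils` code) is:

```python
import os


def max_num_processes() -> int:
    """Illustrative guess at the helper imported above; the real
    open_instruct.utils.max_num_processes may differ."""
    # os.cpu_count() can return None (e.g. in restricted containers), so fall back to 1.
    return max(1, os.cpu_count() or 1)
```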
@@ -1379,16 +1379,25 @@ def __post_init__(self):
         # if the file exists locally, use the local file
         if os.path.exists(self.dataset_name) and self.dataset_name.endswith(".jsonl"):
             assert self.dataset_split == "train", "Only train split is supported for local jsonl files."
-            self.dataset = load_dataset("json", data_files=self.dataset_name, split=self.dataset_split)
+            self.dataset = load_dataset(
+                "json", data_files=self.dataset_name, split=self.dataset_split, num_proc=max_num_processes()
+            )
         elif os.path.exists(self.dataset_name) and self.dataset_name.endswith(".parquet"):
             assert self.dataset_split == "train", "Only train split is supported for local parquet files."
-            self.dataset = load_dataset("parquet", data_files=self.dataset_name, split=self.dataset_split)
+            self.dataset = load_dataset(
+                "parquet", data_files=self.dataset_name, split=self.dataset_split, num_proc=max_num_processes()
+            )
         else:
             # commit hash only works for hf datasets
             self.dataset_commit_hash = get_commit_hash(
                 self.dataset_name, self.dataset_revision, "README.md", "dataset"
             )
-            self.dataset = load_dataset(self.dataset_name, split=self.dataset_split, revision=self.dataset_revision)
+            self.dataset = load_dataset(
+                self.dataset_name,
+                split=self.dataset_split,
+                revision=self.dataset_revision,
+                num_proc=max_num_processes(),
+            )
         if self.dataset_range is None:
             dataset_range = len(self.dataset)
             self.update_range(dataset_range)
@@ -1512,7 +1521,12 @@ def load_or_transform_dataset(
                 print("dataset_skip_cache is True, so we will not load the dataset from cache")
             else:
                 # Use the split from the first dataset config as default
-                return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash)
+                return load_dataset(
+                    repo_name,
+                    split=DEFAULT_SPLIT_FOR_CACHED_DATASET,
+                    revision=self.config_hash,
+                    num_proc=max_num_processes(),
+                )

         print(f"Cache not found, transforming datasets...")

@@ -1565,7 +1579,9 @@ def load_or_transform_dataset(

         # NOTE: Load the dataset again to make sure it's downloaded to the HF cache
         print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{self.config_hash}")
-        return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash)
+        return load_dataset(
+            repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash, num_proc=max_num_processes()
+        )


 class LocalDatasetTransformationCache:
@@ -1931,7 +1947,9 @@ def test_get_cached_dataset_tulu_sft():
         dataset_skip_cache=True,
     )

-    gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="61ac38e052")
+    gold_tokenized_dataset = load_dataset(
+        "allenai/dataset-mix-cached", split="train", revision="61ac38e052", num_proc=max_num_processes()
+    )
     assert len(dataset) == len(gold_tokenized_dataset)
     for i in range(len(dataset)):
         assert dataset[i]["input_ids"] == gold_tokenized_dataset[i]["input_ids"]
@@ -1959,7 +1977,9 @@ def test_get_cached_dataset_tulu_preference():
         TOKENIZED_PREFERENCE_DATASET_KEYS,
         dataset_skip_cache=True,
     )
-    gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="9415479293")
+    gold_tokenized_dataset = load_dataset(
+        "allenai/dataset-mix-cached", split="train", revision="9415479293", num_proc=max_num_processes()
+    )
     assert len(dataset) == len(gold_tokenized_dataset)
     for i in range(len(dataset)):
         assert dataset[i]["chosen_input_ids"] == gold_tokenized_dataset[i]["chosen_input_ids"]
@@ -1987,7 +2007,9 @@ def test_get_cached_dataset_tulu_rlvr():
         transform_fn_args,
         dataset_skip_cache=True,
     )
-    gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="0ff0043e56")
+    gold_tokenized_dataset = load_dataset(
+        "allenai/dataset-mix-cached", split="train", revision="0ff0043e56", num_proc=max_num_processes()
+    )
     assert len(dataset) == len(gold_tokenized_dataset)
     for i in range(len(dataset)):
         assert dataset[i][INPUT_IDS_PROMPT_KEY] == gold_tokenized_dataset[i][INPUT_IDS_PROMPT_KEY]
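Taken together, the change is mechanical: each `load_dataset` call gains a `num_proc` argument. In recent versions of the `datasets` library, `num_proc` is a standard parameter of `load_dataset` that controls how many worker processes download and prepare the data. A standalone usage sketch (the dataset name is a placeholder, not one used in this PR) looks like:

```python
from datasets import load_dataset

# Placeholder repo id for illustration; substitute a real dataset to run this.
ds = load_dataset(
    "some-org/some-dataset",
    split="train",
    num_proc=4,  # parallelize download and preparation across 4 worker processes
)
print(len(ds))
```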