
Commit 1a2b79a

Sets the num_proc argument for all calls to load_dataset in the repo. (#1128)
* Set load_parallel
* Cleaned up PR.
1 parent a13c683 commit 1a2b79a

File tree: 74 files changed (+227 −140 lines)
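
The change is mechanical: every call to load_dataset gains a num_proc argument so that dataset download and preparation can use multiple worker processes, with the process count supplied by a new helper, open_instruct.utils.max_num_processes(). A minimal sketch of the pattern (not taken from the diff; the dataset name is illustrative):

# Sketch of the pattern applied throughout this commit (dataset name is illustrative).
from datasets import load_dataset

from open_instruct.utils import max_num_processes  # helper added by this commit

# num_proc controls how many processes `datasets` uses to download and prepare shards.
ds = load_dataset("allenai/tulu-3-sft-mixture", split="train", num_proc=max_num_processes())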


decontamination/index.py

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,8 @@
 from elasticsearch import Elasticsearch, helpers
 from tqdm import tqdm
 
+import open_instruct.utils as open_instruct_utils
+
 
 def create_text_index(es, index_name):
     mappings = {
@@ -45,7 +47,7 @@ def create_vector_index(es, index_name):
 
 
 def read_dataset(dataset_name, split, messages_field, query_filter, query_field):
-    dataset = load_dataset(dataset_name, split=split)
+    dataset = load_dataset(dataset_name, split=split, num_proc=open_instruct_utils.max_num_processes())
     data_to_index = []
 
     query_filter_key, query_filter_value = query_filter.split(":")

decontamination/search.py

Lines changed: 5 additions & 3 deletions
@@ -11,6 +11,8 @@
 from tqdm import tqdm
 from transformers import AutoModel, AutoTokenizer
 
+import open_instruct.utils as open_instruct_utils
+
 SPACY_MODEL = spacy.load("en_core_web_lg")
 
 
@@ -286,14 +288,14 @@ def main():
     for dataset, subset, split, fields, limit in eval_sets:
         print(f"Querying {index_name} for {dataset}.")
         try:
-            query_dataset = list(load_dataset(dataset, subset, split=split))[:limit]
+            query_dataset = list(load_dataset(dataset, subset, split=split, num_proc=open_instruct_utils.max_num_processes()))[:limit]
         except ValueError:
             query_dataset = []
             if args.subset is None:
                 # Dataset has multiple subsets. We want to concatenate all of them.
                 from datasets import get_dataset_config_names
                 for subset in get_dataset_config_names(dataset):
-                    query_dataset.extend(list(load_dataset(dataset, subset, split=split))[:limit])
+                    query_dataset.extend(list(load_dataset(dataset, subset, split=split, num_proc=open_instruct_utils.max_num_processes()))[:limit])
             else:
                 raise
 
@@ -337,7 +339,7 @@ def main():
     for dataset_name, contaminated_ids in zip(dataset_names, all_index_contaminated_ids):
         print(f"Decontaminating {dataset_name}")
         # Assuming dataset has no subsets and we want the train split.
-        train_dataset = load_dataset(dataset_name, split="train")
+        train_dataset = load_dataset(dataset_name, split="train", num_proc=open_instruct_utils.max_num_processes())
         decontaminated_dataset = []
         num_kept = 0
         num_total = 0

open_instruct/code_utils/test_code_utils.py

Lines changed: 4 additions & 1 deletion
@@ -7,6 +7,7 @@
 import datasets
 import parameterized
 
+import open_instruct.utils as open_instruct_utils
 from open_instruct.code_utils import code_utils
 
 SIMPLE_PROGRAM = "a = 1"
@@ -45,7 +46,9 @@ def test_all_fail_or_timeout(self):
 
     def test_tiger_lab_acecode_sample(self):
         """Tests the script against an actual AceCode record."""
-        ds = datasets.load_dataset("TIGER-Lab/AceCode-87K", split="train")
+        ds = datasets.load_dataset(
+            "TIGER-Lab/AceCode-87K", split="train", num_proc=open_instruct_utils.max_num_processes()
+        )
 
         # Choose the same sample index used in the original snippet.
         i = 1

open_instruct/dataset_transformation.py

Lines changed: 31 additions & 9 deletions
@@ -66,7 +66,7 @@
 )
 from transformers.utils.hub import _CACHED_NO_EXIST, TRANSFORMERS_CACHE, extract_commit_hash, try_to_load_from_cache
 
-from open_instruct.utils import hf_whoami
+from open_instruct.utils import hf_whoami, max_num_processes
 
 
 # ----------------------------------------------------------------------------
@@ -1379,16 +1379,25 @@ def __post_init__(self):
        # if the file exists locally, use the local file
        if os.path.exists(self.dataset_name) and self.dataset_name.endswith(".jsonl"):
            assert self.dataset_split == "train", "Only train split is supported for local jsonl files."
-            self.dataset = load_dataset("json", data_files=self.dataset_name, split=self.dataset_split)
+            self.dataset = load_dataset(
+                "json", data_files=self.dataset_name, split=self.dataset_split, num_proc=max_num_processes()
+            )
        elif os.path.exists(self.dataset_name) and self.dataset_name.endswith(".parquet"):
            assert self.dataset_split == "train", "Only train split is supported for local parquet files."
-            self.dataset = load_dataset("parquet", data_files=self.dataset_name, split=self.dataset_split)
+            self.dataset = load_dataset(
+                "parquet", data_files=self.dataset_name, split=self.dataset_split, num_proc=max_num_processes()
+            )
        else:
            # commit hash only works for hf datasets
            self.dataset_commit_hash = get_commit_hash(
                self.dataset_name, self.dataset_revision, "README.md", "dataset"
            )
-            self.dataset = load_dataset(self.dataset_name, split=self.dataset_split, revision=self.dataset_revision)
+            self.dataset = load_dataset(
+                self.dataset_name,
+                split=self.dataset_split,
+                revision=self.dataset_revision,
+                num_proc=max_num_processes(),
+            )
        if self.dataset_range is None:
            dataset_range = len(self.dataset)
            self.update_range(dataset_range)
@@ -1512,7 +1521,12 @@ def load_or_transform_dataset(
                print("dataset_skip_cache is True, so we will not load the dataset from cache")
            else:
                # Use the split from the first dataset config as default
-                return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash)
+                return load_dataset(
+                    repo_name,
+                    split=DEFAULT_SPLIT_FOR_CACHED_DATASET,
+                    revision=self.config_hash,
+                    num_proc=max_num_processes(),
+                )
 
        print(f"Cache not found, transforming datasets...")
 
@@ -1565,7 +1579,9 @@ def load_or_transform_dataset(
 
        # NOTE: Load the dataset again to make sure it's downloaded to the HF cache
        print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{self.config_hash}")
-        return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash)
+        return load_dataset(
+            repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash, num_proc=max_num_processes()
+        )
 
 
 class LocalDatasetTransformationCache:
@@ -1931,7 +1947,9 @@ def test_get_cached_dataset_tulu_sft():
        dataset_skip_cache=True,
    )
 
-    gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="61ac38e052")
+    gold_tokenized_dataset = load_dataset(
+        "allenai/dataset-mix-cached", split="train", revision="61ac38e052", num_proc=max_num_processes()
+    )
    assert len(dataset) == len(gold_tokenized_dataset)
    for i in range(len(dataset)):
        assert dataset[i]["input_ids"] == gold_tokenized_dataset[i]["input_ids"]
@@ -1959,7 +1977,9 @@ def test_get_cached_dataset_tulu_preference():
        TOKENIZED_PREFERENCE_DATASET_KEYS,
        dataset_skip_cache=True,
    )
-    gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="9415479293")
+    gold_tokenized_dataset = load_dataset(
+        "allenai/dataset-mix-cached", split="train", revision="9415479293", num_proc=max_num_processes()
+    )
    assert len(dataset) == len(gold_tokenized_dataset)
    for i in range(len(dataset)):
        assert dataset[i]["chosen_input_ids"] == gold_tokenized_dataset[i]["chosen_input_ids"]
@@ -1987,7 +2007,9 @@ def test_get_cached_dataset_tulu_rlvr():
        transform_fn_args,
        dataset_skip_cache=True,
    )
-    gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="0ff0043e56")
+    gold_tokenized_dataset = load_dataset(
+        "allenai/dataset-mix-cached", split="train", revision="0ff0043e56", num_proc=max_num_processes()
+    )
    assert len(dataset) == len(gold_tokenized_dataset)
    for i in range(len(dataset)):
        assert dataset[i][INPUT_IDS_PROMPT_KEY] == gold_tokenized_dataset[i][INPUT_IDS_PROMPT_KEY]

open_instruct/utils.py

Lines changed: 13 additions & 5 deletions
@@ -75,6 +75,14 @@
 DataClassType = NewType("DataClassType", Any)
 
 
+def max_num_processes() -> int:
+    """Returns a reasonable default number of processes to run for multiprocessing."""
+    if hasattr(os, "sched_getaffinity"):
+        return len(os.sched_getaffinity(0))
+    else:
+        return os.cpu_count() or 1
+
+
 def repeat_each(seq, k):
     """Repeat each element in a sequence k times."""
     return [item for item in seq for _ in range(k)]
@@ -315,13 +323,13 @@ def get_datasets(
    for split in splits:
        # if dataset ends with .json or .jsonl, load from file
        if ds.endswith(".json") or ds.endswith(".jsonl"):
-            dataset = load_dataset("json", data_files=ds, split=split)
+            dataset = load_dataset("json", data_files=ds, split=split, num_proc=max_num_processes())
        elif ds.endswith(".parquet"):
-            dataset = load_dataset("parquet", data_files=ds, split=split)
+            dataset = load_dataset("parquet", data_files=ds, split=split, num_proc=max_num_processes())
        else:
            try:
                # Try first if dataset on a Hub repo
-                dataset = load_dataset(ds, ds_config, split=split)
+                dataset = load_dataset(ds, ds_config, split=split, num_proc=max_num_processes())
            except DatasetGenerationError:
                # If not, check local dataset
                dataset = load_from_disk(os.path.join(ds, split))
@@ -529,11 +537,11 @@ def combine_dataset(
    for (ds, frac_or_samples), ds_config, split in zip(dataset_mixer.items(), configs, splits):
        # if dataset ends with .json or .jsonl, load from file
        if ds.endswith(".json") or ds.endswith(".jsonl"):
-            dataset = load_dataset("json", data_files=ds, split=split)
+            dataset = load_dataset("json", data_files=ds, split=split, num_proc=max_num_processes())
        else:
            try:
                # Try first if dataset on a Hub repo
-                dataset = load_dataset(ds, ds_config, split=split)
+                dataset = load_dataset(ds, ds_config, split=split, num_proc=max_num_processes())
            except DatasetGenerationError:
                # If not, check local dataset
                dataset = load_from_disk(os.path.join(ds, split))
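
A note on the new helper: os.sched_getaffinity(0) is preferred because it reports the CPUs the current process is actually allowed to use (for example under taskset or a cgroup CPU set), while os.cpu_count() reports every CPU on the machine; the fallback covers platforms such as macOS and Windows, where sched_getaffinity does not exist. An illustrative comparison (the core counts are hypothetical):

import os

# Hypothetical Linux host with 16 CPUs, launched as: taskset -c 0-3 python script.py
print(os.cpu_count())                # 16 -- every CPU on the machine
print(len(os.sched_getaffinity(0)))  # 4  -- CPUs this process may actually use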

quantize/quantize_autogptq_wikitext.py

Lines changed: 4 additions & 2 deletions
@@ -18,10 +18,12 @@
 from datasets import load_dataset
 from transformers import AutoTokenizer
 
+import open_instruct.utils as open_instruct_utils
+
 
 def get_wikitext2(nsamples, seed, seqlen, model):
-    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", num_proc=open_instruct_utils.max_num_processes())
+    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", num_proc=open_instruct_utils.max_num_processes())
 
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
     trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")

scripts/create_ground_truth_data.py

Lines changed: 7 additions & 6 deletions
@@ -4,6 +4,7 @@
 import random
 
 from datasets import Dataset, load_dataset
+import open_instruct.utils as open_instruct_utils
 from tqdm import tqdm
 
 from open_instruct.math_utils import last_boxed_only_string, remove_boxed
@@ -86,7 +87,7 @@
 for sample in GSM8K_EXEMPLARS:
     gsm8k_prompt += f"Question: {sample['question'].strip()}\nAnswer:{sample['cot_answer'].strip()}\n\n"
 
-gsm8k_dataset = load_dataset("gsm8k", "main", split="train")
+gsm8k_dataset = load_dataset("gsm8k", "main", split="train", num_proc=open_instruct_utils.max_num_processes())
 new_data = []
 for sample in gsm8k_dataset:
     answer = sample["answer"].split("####")[-1].strip()
@@ -97,7 +98,7 @@
     })
 
 # also make a test split for eval
-gsm8k_dataset = load_dataset("gsm8k", "main", split="test")
+gsm8k_dataset = load_dataset("gsm8k", "main", split="test", num_proc=open_instruct_utils.max_num_processes())
 test_data = []
 for sample in gsm8k_dataset:
     answer = sample["answer"].split("####")[-1].strip()
@@ -111,7 +112,7 @@
 math_prompt = ""
 for sample in MATH_EXAMPLARS:
     math_prompt += f"Question: {sample['question'].strip()}\nAnswer:{sample['cot_answer'].strip()}\n\n"
-math_dataset = load_dataset("lighteval/MATH", "all", split="train")
+math_dataset = load_dataset("lighteval/MATH", "all", split="train", num_proc=open_instruct_utils.max_num_processes())
 for sample in math_dataset:
     # same code used to extract answer for eval
     answer = remove_boxed(last_boxed_only_string(sample["solution"]))
@@ -132,7 +133,7 @@
 # dataset.push_to_hub("ai2-adapt-dev/gsm8k_math_ground_truth")
 
 # # alternate dataset: metamathqa!
-# metamathqa_dataset = load_dataset("meta-math/MetaMathQA", "main", split="train")
+# metamathqa_dataset = load_dataset("meta-math/MetaMathQA", "main", split="train", num_proc=open_instruct_utils.max_num_processes())
 # # let's re-use the MATH prompt.
 # new_data = []
 # def extract_answer(text):
@@ -158,7 +159,7 @@
 # dataset.push_to_hub("ai2-adapt-dev/metamathqa_ground_truth")
 
 # alternate dataset: numina-tir
-metamathqa_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")
+metamathqa_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train", num_proc=open_instruct_utils.max_num_processes())
 # let's re-use the MATH prompt.
 new_data = []
 def find_last_outermost_boxed(string):
@@ -209,7 +210,7 @@ def find_last_outermost_boxed(string):
 dataset.push_to_hub("ai2-adapt-dev/numinamath_tir_ground_truth_one_turn")
 
 # alternate dataset: numina-cot (much, much larger)
-metamathqa_dataset = load_dataset("AI-MO/NuminaMath-CoT", split="train")
+metamathqa_dataset = load_dataset("AI-MO/NuminaMath-CoT", split="train", num_proc=open_instruct_utils.max_num_processes())
 # let's re-use the MATH prompt.
 new_data = []
 for sample in tqdm(metamathqa_dataset):

scripts/data/azure_batch/process_azure_batch_results.py

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ def process_batch_results(
     max_rows: Optional[int] = None,
 ):
     # Load the original dataset first so we can look up failed prompts
-    original_ds = datasets.load_dataset(input_dataset, split=split)
+    original_ds = datasets.load_dataset(input_dataset, split=split, num_proc=max_num_processes())
     id_lookup = {row["id"]: row for row in original_ds}
 
     all_batch_results = {}

scripts/data/azure_batch/regenerate_dataset_completions.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def main(sample_limit: int | None = None,
     os.makedirs(f"{current_dir}/batch_files", exist_ok=True)
 
     print(f"Loading dataset {input_dataset_name} with split {split}")
-    input_dataset = datasets.load_dataset(input_dataset_name, split=split)
+    input_dataset = datasets.load_dataset(input_dataset_name, split=split, num_proc=max_num_processes())
 
     # First get all unique IDs
     print(f'Processing dataset with {len(input_dataset)} rows')

scripts/data/build_hardcoded.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 from functools import partial
 
 from datasets import DatasetDict, load_dataset
+import open_instruct.utils as open_instruct_utils
 from huggingface_hub import HfApi
 
 from open_instruct import logger_utils
@@ -281,7 +282,7 @@ def main():
     # --- Load Source Dataset ---
     try:
         logger.info(f"Loading source dataset '{args.source_repo}'...")
-        original_dataset = load_dataset(args.source_repo)
+        original_dataset = load_dataset(args.source_repo, num_proc=open_instruct_utils.max_num_processes())
         logger.info(f"Dataset loaded successfully. Splits: {list(original_dataset.keys())}")
     except Exception as e:
         logger.error(f"Failed to load source dataset '{args.source_repo}': {e}")
