4 changes: 3 additions & 1 deletion decontamination/index.py
@@ -6,6 +6,8 @@
from elasticsearch import Elasticsearch, helpers
from tqdm import tqdm

import open_instruct.utils as open_instruct_utils


def create_text_index(es, index_name):
mappings = {
@@ -45,7 +47,7 @@ def create_vector_index(es, index_name):


def read_dataset(dataset_name, split, messages_field, query_filter, query_field):
dataset = load_dataset(dataset_name, split=split)
dataset = load_dataset(dataset_name, split=split, num_proc=open_instruct_utils.max_num_processes())
data_to_index = []

query_filter_key, query_filter_value = query_filter.split(":")
8 changes: 5 additions & 3 deletions decontamination/search.py
@@ -11,6 +11,8 @@
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

import open_instruct.utils as open_instruct_utils

SPACY_MODEL = spacy.load("en_core_web_lg")


@@ -286,14 +288,14 @@ def main():
for dataset, subset, split, fields, limit in eval_sets:
print(f"Querying {index_name} for {dataset}.")
try:
query_dataset = list(load_dataset(dataset, subset, split=split))[:limit]
query_dataset = list(load_dataset(dataset, subset, split=split, num_proc=open_instruct_utils.max_num_processes()))[:limit]
except ValueError:
query_dataset = []
if args.subset is None:
# Dataset has multiple subsets. We want to concatenate all of them.
from datasets import get_dataset_config_names
for subset in get_dataset_config_names(dataset):
query_dataset.extend(list(load_dataset(dataset, subset, split=split))[:limit])
query_dataset.extend(list(load_dataset(dataset, subset, split=split, num_proc=open_instruct_utils.max_num_processes()))[:limit])
else:
raise

@@ -337,7 +339,7 @@ def main():
for dataset_name, contaminated_ids in zip(dataset_names, all_index_contaminated_ids):
print(f"Decontaminating {dataset_name}")
# Assuming dataset has no subsets and we want the train split.
train_dataset = load_dataset(dataset_name, split="train")
train_dataset = load_dataset(dataset_name, split="train", num_proc=open_instruct_utils.max_num_processes())
decontaminated_dataset = []
num_kept = 0
num_total = 0
5 changes: 4 additions & 1 deletion open_instruct/code_utils/test_code_utils.py
@@ -7,6 +7,7 @@
import datasets
import parameterized

import open_instruct.utils as open_instruct_utils
from open_instruct.code_utils import code_utils

SIMPLE_PROGRAM = "a = 1"
@@ -45,7 +46,9 @@ def test_all_fail_or_timeout(self):

def test_tiger_lab_acecode_sample(self):
"""Tests the script against an actual AceCode record."""
ds = datasets.load_dataset("TIGER-Lab/AceCode-87K", split="train")
ds = datasets.load_dataset(
"TIGER-Lab/AceCode-87K", split="train", num_proc=open_instruct_utils.max_num_processes()
)

# Choose the same sample index used in the original snippet.
i = 1
40 changes: 31 additions & 9 deletions open_instruct/dataset_transformation.py
@@ -66,7 +66,7 @@
)
from transformers.utils.hub import _CACHED_NO_EXIST, TRANSFORMERS_CACHE, extract_commit_hash, try_to_load_from_cache

from open_instruct.utils import hf_whoami
from open_instruct.utils import hf_whoami, max_num_processes


# ----------------------------------------------------------------------------
@@ -1379,16 +1379,25 @@ def __post_init__(self):
# if the file exists locally, use the local file
if os.path.exists(self.dataset_name) and self.dataset_name.endswith(".jsonl"):
assert self.dataset_split == "train", "Only train split is supported for local jsonl files."
self.dataset = load_dataset("json", data_files=self.dataset_name, split=self.dataset_split)
self.dataset = load_dataset(
"json", data_files=self.dataset_name, split=self.dataset_split, num_proc=max_num_processes()
)
elif os.path.exists(self.dataset_name) and self.dataset_name.endswith(".parquet"):
assert self.dataset_split == "train", "Only train split is supported for local parquet files."
self.dataset = load_dataset("parquet", data_files=self.dataset_name, split=self.dataset_split)
self.dataset = load_dataset(
"parquet", data_files=self.dataset_name, split=self.dataset_split, num_proc=max_num_processes()
)
else:
# commit hash only works for hf datasets
self.dataset_commit_hash = get_commit_hash(
self.dataset_name, self.dataset_revision, "README.md", "dataset"
)
self.dataset = load_dataset(self.dataset_name, split=self.dataset_split, revision=self.dataset_revision)
self.dataset = load_dataset(
self.dataset_name,
split=self.dataset_split,
revision=self.dataset_revision,
num_proc=max_num_processes(),
)
if self.dataset_range is None:
dataset_range = len(self.dataset)
self.update_range(dataset_range)
@@ -1512,7 +1521,12 @@ def load_or_transform_dataset(
print("dataset_skip_cache is True, so we will not load the dataset from cache")
else:
# Use the split from the first dataset config as default
return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash)
return load_dataset(
repo_name,
split=DEFAULT_SPLIT_FOR_CACHED_DATASET,
revision=self.config_hash,
num_proc=max_num_processes(),
)

print(f"Cache not found, transforming datasets...")

@@ -1565,7 +1579,9 @@ def load_or_transform_dataset(

# NOTE: Load the dataset again to make sure it's downloaded to the HF cache
print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{self.config_hash}")
return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash)
return load_dataset(
repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=self.config_hash, num_proc=max_num_processes()
)


class LocalDatasetTransformationCache:
@@ -1931,7 +1947,9 @@ def test_get_cached_dataset_tulu_sft():
dataset_skip_cache=True,
)

gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="61ac38e052")
gold_tokenized_dataset = load_dataset(
"allenai/dataset-mix-cached", split="train", revision="61ac38e052", num_proc=max_num_processes()
)
assert len(dataset) == len(gold_tokenized_dataset)
for i in range(len(dataset)):
assert dataset[i]["input_ids"] == gold_tokenized_dataset[i]["input_ids"]
@@ -1959,7 +1977,9 @@ def test_get_cached_dataset_tulu_preference():
TOKENIZED_PREFERENCE_DATASET_KEYS,
dataset_skip_cache=True,
)
gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="9415479293")
gold_tokenized_dataset = load_dataset(
"allenai/dataset-mix-cached", split="train", revision="9415479293", num_proc=max_num_processes()
)
assert len(dataset) == len(gold_tokenized_dataset)
for i in range(len(dataset)):
assert dataset[i]["chosen_input_ids"] == gold_tokenized_dataset[i]["chosen_input_ids"]
@@ -1987,7 +2007,9 @@ def test_get_cached_dataset_tulu_rlvr():
transform_fn_args,
dataset_skip_cache=True,
)
gold_tokenized_dataset = load_dataset("allenai/dataset-mix-cached", split="train", revision="0ff0043e56")
gold_tokenized_dataset = load_dataset(
"allenai/dataset-mix-cached", split="train", revision="0ff0043e56", num_proc=max_num_processes()
)
assert len(dataset) == len(gold_tokenized_dataset)
for i in range(len(dataset)):
assert dataset[i][INPUT_IDS_PROMPT_KEY] == gold_tokenized_dataset[i][INPUT_IDS_PROMPT_KEY]
18 changes: 13 additions & 5 deletions open_instruct/utils.py
@@ -75,6 +75,14 @@
DataClassType = NewType("DataClassType", Any)


def max_num_processes() -> int:
"""Returns a reasonable default number of processes to run for multiprocessing."""
if hasattr(os, "sched_getaffinity"):
return len(os.sched_getaffinity(0))
else:
return os.cpu_count() or 1


def repeat_each(seq, k):
"""Repeat each element in a sequence k times."""
return [item for item in seq for _ in range(k)]
@@ -315,13 +323,13 @@ def get_datasets(
for split in splits:
# if dataset ends with .json or .jsonl, load from file
if ds.endswith(".json") or ds.endswith(".jsonl"):
dataset = load_dataset("json", data_files=ds, split=split)
dataset = load_dataset("json", data_files=ds, split=split, num_proc=max_num_processes())
elif ds.endswith(".parquet"):
dataset = load_dataset("parquet", data_files=ds, split=split)
dataset = load_dataset("parquet", data_files=ds, split=split, num_proc=max_num_processes())
else:
try:
# Try first if dataset on a Hub repo
dataset = load_dataset(ds, ds_config, split=split)
dataset = load_dataset(ds, ds_config, split=split, num_proc=max_num_processes())
except DatasetGenerationError:
# If not, check local dataset
dataset = load_from_disk(os.path.join(ds, split))
@@ -529,11 +537,11 @@ def combine_dataset(
for (ds, frac_or_samples), ds_config, split in zip(dataset_mixer.items(), configs, splits):
# if dataset ends with .json or .jsonl, load from file
if ds.endswith(".json") or ds.endswith(".jsonl"):
dataset = load_dataset("json", data_files=ds, split=split)
dataset = load_dataset("json", data_files=ds, split=split, num_proc=max_num_processes())
else:
try:
# Try first if dataset on a Hub repo
dataset = load_dataset(ds, ds_config, split=split)
dataset = load_dataset(ds, ds_config, split=split, num_proc=max_num_processes())
except DatasetGenerationError:
# If not, check local dataset
dataset = load_from_disk(os.path.join(ds, split))
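A minimal standalone sketch of the pattern this PR applies throughout: the helper body below mirrors the new `max_num_processes()` added in `open_instruct/utils.py`, and the value is threaded into `load_dataset` via `num_proc`. The dataset name is only an example (it appears elsewhere in this diff), not part of the change itself.

```python
# Sketch only, not part of the diff: mirrors the new helper in
# open_instruct/utils.py and shows how num_proc is passed to load_dataset.
import os

from datasets import load_dataset


def max_num_processes() -> int:
    """Return a reasonable default process count for multiprocessing."""
    if hasattr(os, "sched_getaffinity"):
        # Honors CPU affinity / cgroup limits where the platform exposes them (Linux).
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", num_proc=max_num_processes())
```

Checking `sched_getaffinity` first means jobs pinned to a subset of cores (e.g., in containers) do not oversubscribe workers, while the `cpu_count() or 1` fallback keeps the call safe on platforms without affinity support.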
6 changes: 4 additions & 2 deletions quantize/quantize_autogptq_wikitext.py
@@ -18,10 +18,12 @@
from datasets import load_dataset
from transformers import AutoTokenizer

import open_instruct.utils as open_instruct_utils


def get_wikitext2(nsamples, seed, seqlen, model):
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", num_proc=open_instruct_utils.max_num_processes())
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", num_proc=open_instruct_utils.max_num_processes())

tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
13 changes: 7 additions & 6 deletions scripts/create_ground_truth_data.py
@@ -4,6 +4,7 @@
import random

from datasets import Dataset, load_dataset
import open_instruct.utils as open_instruct_utils
from tqdm import tqdm

from open_instruct.math_utils import last_boxed_only_string, remove_boxed
@@ -86,7 +87,7 @@
for sample in GSM8K_EXEMPLARS:
gsm8k_prompt += f"Question: {sample['question'].strip()}\nAnswer:{sample['cot_answer'].strip()}\n\n"

gsm8k_dataset = load_dataset("gsm8k", "main", split="train")
gsm8k_dataset = load_dataset("gsm8k", "main", split="train", num_proc=open_instruct_utils.max_num_processes())
new_data = []
for sample in gsm8k_dataset:
answer = sample["answer"].split("####")[-1].strip()
@@ -97,7 +98,7 @@
})

# also make a test split for eval
gsm8k_dataset = load_dataset("gsm8k", "main", split="test")
gsm8k_dataset = load_dataset("gsm8k", "main", split="test", num_proc=open_instruct_utils.max_num_processes())
test_data = []
for sample in gsm8k_dataset:
answer = sample["answer"].split("####")[-1].strip()
@@ -111,7 +112,7 @@
math_prompt = ""
for sample in MATH_EXAMPLARS:
math_prompt += f"Question: {sample['question'].strip()}\nAnswer:{sample['cot_answer'].strip()}\n\n"
math_dataset = load_dataset("lighteval/MATH", "all", split="train")
math_dataset = load_dataset("lighteval/MATH", "all", split="train", num_proc=open_instruct_utils.max_num_processes())
for sample in math_dataset:
# same code used to extract answer for eval
answer = remove_boxed(last_boxed_only_string(sample["solution"]))
@@ -132,7 +133,7 @@
# dataset.push_to_hub("ai2-adapt-dev/gsm8k_math_ground_truth")

# # alternate dataset: metamathqa!
# metamathqa_dataset = load_dataset("meta-math/MetaMathQA", "main", split="train")
# metamathqa_dataset = load_dataset("meta-math/MetaMathQA", "main", split="train", num_proc=open_instruct_utils.max_num_processes())
# # let's re-use the MATH prompt.
# new_data = []
# def extract_answer(text):
@@ -158,7 +159,7 @@
# dataset.push_to_hub("ai2-adapt-dev/metamathqa_ground_truth")

# alternate dataset: numina-tir
metamathqa_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")
metamathqa_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train", num_proc=open_instruct_utils.max_num_processes())
# let's re-use the MATH prompt.
new_data = []
def find_last_outermost_boxed(string):
@@ -209,7 +210,7 @@ def find_last_outermost_boxed(string):
dataset.push_to_hub("ai2-adapt-dev/numinamath_tir_ground_truth_one_turn")

# alternate dataset: numina-cot (much, much larger)
metamathqa_dataset = load_dataset("AI-MO/NuminaMath-CoT", split="train")
metamathqa_dataset = load_dataset("AI-MO/NuminaMath-CoT", split="train", num_proc=open_instruct_utils.max_num_processes())
# let's re-use the MATH prompt.
new_data = []
for sample in tqdm(metamathqa_dataset):
2 changes: 1 addition & 1 deletion scripts/data/azure_batch/process_azure_batch_results.py
@@ -277,7 +277,7 @@ def process_batch_results(
max_rows: Optional[int] = None,
):
# Load the original dataset first so we can look up failed prompts
original_ds = datasets.load_dataset(input_dataset, split=split)
original_ds = datasets.load_dataset(input_dataset, split=split, num_proc=max_num_processes())
id_lookup = {row["id"]: row for row in original_ds}

all_batch_results = {}
2 changes: 1 addition & 1 deletion scripts/data/azure_batch/regenerate_dataset_completions.py
@@ -188,7 +188,7 @@ def main(sample_limit: int | None = None,
os.makedirs(f"{current_dir}/batch_files", exist_ok=True)

print(f"Loading dataset {input_dataset_name} with split {split}")
input_dataset = datasets.load_dataset(input_dataset_name, split=split)
input_dataset = datasets.load_dataset(input_dataset_name, split=split, num_proc=max_num_processes())

# First get all unique IDs
print(f'Processing dataset with {len(input_dataset)} rows')
3 changes: 2 additions & 1 deletion scripts/data/build_hardcoded.py
@@ -3,6 +3,7 @@
from functools import partial

from datasets import DatasetDict, load_dataset
import open_instruct.utils as open_instruct_utils
from huggingface_hub import HfApi

from open_instruct import logger_utils
@@ -281,7 +282,7 @@ def main():
# --- Load Source Dataset ---
try:
logger.info(f"Loading source dataset '{args.source_repo}'...")
original_dataset = load_dataset(args.source_repo)
original_dataset = load_dataset(args.source_repo, num_proc=open_instruct_utils.max_num_processes())
logger.info(f"Dataset loaded successfully. Splits: {list(original_dataset.keys())}")
except Exception as e:
logger.error(f"Failed to load source dataset '{args.source_repo}': {e}")
3 changes: 2 additions & 1 deletion scripts/data/convert_general_thought_to_tulu_thinker.py
@@ -9,10 +9,11 @@
import random

from datasets import Dataset, load_dataset
import open_instruct.utils as open_instruct_utils

random_gen = random.Random(42)

ds = load_dataset("natolambert/GeneralThought-430K-filtered", split="train")
ds = load_dataset("natolambert/GeneralThought-430K-filtered", split="train", num_proc=open_instruct_utils.max_num_processes())
new_data = []

for sample in ds:
11 changes: 6 additions & 5 deletions scripts/data/create_deepscaler_data.py
@@ -1,11 +1,12 @@
import random

from datasets import Dataset, load_dataset
import open_instruct.utils as open_instruct_utils
from tqdm import tqdm

random_gen = random.Random(42)

dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train")
dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", num_proc=open_instruct_utils.max_num_processes())

# regular dataset
new_data = []
@@ -19,7 +20,7 @@
dataset = Dataset.from_list(new_data)
dataset.push_to_hub("ai2-adapt-dev/deepscaler-gt")

dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train")
dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", num_proc=open_instruct_utils.max_num_processes())
# 4k length only
new_data = []
for sample in tqdm(dataset):
@@ -33,7 +34,7 @@
dataset = Dataset.from_list(new_data)
dataset.push_to_hub("ai2-adapt-dev/deepscaler_gt_random_max_length_only")

dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train")
dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", num_proc=open_instruct_utils.max_num_processes())
# 4k length and solution
new_data = []
for sample in tqdm(dataset):
@@ -47,7 +48,7 @@
dataset = Dataset.from_list(new_data)
dataset.push_to_hub("ai2-adapt-dev/deepscaler_gt_with_random_max_length")

dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train")
dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", num_proc=open_instruct_utils.max_num_processes())
# 8k length only
new_data = []
for sample in tqdm(dataset):
@@ -61,7 +62,7 @@
dataset = Dataset.from_list(new_data)
dataset.push_to_hub("ai2-adapt-dev/deepscaler_gt_random_max_length_only_8192")

dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train")
dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", num_proc=open_instruct_utils.max_num_processes())
# 8k length and solution
new_data = []
for sample in tqdm(dataset):