41 commits
e8bda1c
add sft argument and preliminary implementation of SFTIndexedDataset
RaphaelKreft Oct 2, 2025
c9b5701
make SFTIndexedDataset align with GPTDataset.
RaphaelKreft Oct 2, 2025
5bfcacb
Implement masking of user sequences
RaphaelKreft Oct 4, 2025
5a5b5de
Debug SFTIndexed Dataset
Oct 5, 2025
6e458a8
add masking for attention rows for padding tokens
RaphaelKreft Oct 5, 2025
4ff1170
Remove special tokens from user start/end sequence
Oct 5, 2025
df80bc2
end truncated samples with eod token
RaphaelKreft Oct 6, 2025
912d683
Add option to mask special token(sequences) in sft dataset and improv…
RaphaelKreft Oct 7, 2025
c591b9a
debugged sdt and tokenizer. added options to not mask image tokens
Oct 11, 2025
812ea3e
fix: use labels to calculate values for loss mask NOT tokens
RaphaelKreft Oct 13, 2025
bc021e1
implement right padding, add debug flag, remove goldfish loss, cleanup
RaphaelKreft Oct 14, 2025
49a4e8a
remove caching as it potentially can cause issues. Add cmd arg for plw
RaphaelKreft Oct 16, 2025
96ea32d
Add tokenizer properties for pre-tokenized SFT sequences
Alvorecer721 Oct 16, 2025
9f14f5b
Use pre-tokenized sequences from tokenizer config
Alvorecer721 Oct 16, 2025
d832f6e
add support for plw in sft_dataset.py.
RaphaelKreft Oct 16, 2025
74baf69
fix: add sft_plw to gpt dataset config properly
Oct 16, 2025
7b81b95
add assistant loss logging for sft (untested)
RaphaelKreft Oct 17, 2025
515977b
improve assistant loss logging to work with plw=1 (untested)
RaphaelKreft Oct 20, 2025
1ee7c2e
fix assistant mask not passed though callchain correctly
Oct 20, 2025
56fbc8c
implement sample-packing for sft dataset (untested)
RaphaelKreft Oct 21, 2025
e14e9cf
add and use python-version of build_packed_idx method in sft_dataset …
RaphaelKreft Oct 21, 2025
0fcbc7a
Fix minor issues to make packing training launch and compile
Oct 22, 2025
846ab71
add sft init script that can init sft dataset once to print packed sa…
RaphaelKreft Oct 22, 2025
3693f1e
log packing statistics also if loaded from cache
RaphaelKreft Oct 22, 2025
572b92d
remove unecessary prints from init sft script. Sft dset prints sttist…
Oct 22, 2025
07b185c
Add --final-checkpoint arg, that enables storing a final checkpoint r…
RaphaelKreft Oct 23, 2025
4f1ee44
Add option to skip skip margin samples, to exactly reach target numbe…
RaphaelKreft Oct 24, 2025
b89133c
Add option to skip skip margin samples, to exactly reach target numbe…
RaphaelKreft Oct 24, 2025
217f3ea
Minor fixes to sft-dataset
Oct 24, 2025
c5eef18
re-enable forward pre-hook before final checkpoint
RaphaelKreft Oct 26, 2025
ba0e5c6
activate c++ index building helper, deprecate python one
RaphaelKreft Oct 28, 2025
53b4619
remove not-mask img tokens, add sample averaging and multi-epoch supp…
RaphaelKreft Nov 6, 2025
4192389
use correct shuffle in single doc idx building
Nov 6, 2025
129d45a
extend key conf attributes of sft_dataset to account for packing
RaphaelKreft Nov 6, 2025
8383953
cleanup old python index building implementation. Limit sample index …
RaphaelKreft Nov 7, 2025
1ea97b4
store correct size of sample index and mitigate NaN when equalizing s…
RaphaelKreft Nov 7, 2025
bf5d2b7
move zero-out of loss-mask into respective method
RaphaelKreft Nov 7, 2025
da8cb3b
fix attempt NaN loss
RaphaelKreft Nov 7, 2025
ab43c71
sft-dataset: obtain sample boundaries during low level sample loading…
RaphaelKreft Nov 17, 2025
bf74020
add load loss mask from disk option - cleanup sft dataset
RaphaelKreft Nov 19, 2025
8c5188f
minor fixes to sft_dataset.py after refactor
Nov 19, 2025
224 changes: 224 additions & 0 deletions initialize_sft_dataset.py
@@ -0,0 +1,224 @@
#!/usr/bin/env python
"""Initialize SFT dataset with packing to determine sample counts.

This script initializes the SFT dataset in the same way as pretrain_gpt.py
but exits immediately after dataset initialization. This is useful for the
first phase of packed SFT training where you need to determine the actual
number of packed samples without loading the full model.

IMPORTANT REQUIREMENTS:
- SEED MUST MATCH your intended training run (determines packing)
- Parallelism settings (TP/PP/EP) MUST match your actual training run
- World size can be smaller (minimal DP) but TP/PP/EP must be identical
- Model architecture params are only needed for validation (not used in packing)

Usage:
python initialize_sft_dataset.py <same arguments as pretrain_gpt.py>

Must include: --sft --sft-pack-samples

Example (for a training run with TP=8, PP=4):
torchrun --nproc_per_node=32 initialize_sft_dataset.py \\
--tensor-model-parallel-size 8 \\
--pipeline-model-parallel-size 4 \\
--num-layers 32 \\
--hidden-size 4096 \\
--seq-length 2048 \\
--data-path /path/to/data \\
--tokenizer-type GPT2BPETokenizer \\
--vocab-file /path/to/vocab.json \\
--merge-file /path/to/merges.txt \\
--sft \\
--sft-pack-samples \\
--train-iters 1000 \\
--global-batch-size 8

# Note: nproc_per_node = TP * PP = 32 (world size can be = TP*PP*minimal_DP)
"""

import sys
from typing import List, Optional, Tuple

from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import get_blend_and_blend_per_split


def is_dataset_built_on_rank():
"""Determine if dataset should be built on this rank."""
return (
mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
) and mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
"""Create GPTDatasetConfig from command line arguments."""
tokenizer = get_tokenizer()

# Sometimes --data-path is too long; in that case we parse it from a file instead.
blend: Optional[Tuple[List[str], Optional[List[float]]]]
blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
blend, blend_per_split = get_blend_and_blend_per_split(args)

# Double sequence length if loading loss masks from disk (dataset stores tokens + loss_mask concatenated)
sequence_length = args.seq_length
if args.sft and args.sft_load_loss_mask:
sequence_length = args.seq_length * 2
print_rank_0(f"> SFT: Loading loss masks from disk, doubling dataset sequence_length to {sequence_length} "
f"(model will see {args.seq_length} tokens)")

return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=sequence_length,
blend=blend,
blend_per_split=blend_per_split,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
mmap_bin_files=args.mmap_bin_files,
tokenizer=tokenizer,
reset_position_ids=args.reset_position_ids,
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
s3_cache_path=args.s3_cache_path,
goldfish_loss=args.goldfish_loss,
goldfish_k=args.goldfish_k,
goldfish_h=args.goldfish_h,
sft_mask_special_tokens=args.sft_mask_special_tokens,
sft_plw=args.sft_plw,
sft_pack_samples=args.sft_pack_samples,
sft_equalize_sample_loss=args.sft_equalize_sample_loss,
sft_load_loss_mask=args.sft_load_loss_mask,
sft_disable_assistant_mask=args.sft_disable_assistant_mask,
skip_margin_samples=args.data_skip_margin_samples
)


def build_train_valid_test_datasets(train_val_test_num_samples):
"""Build the train, test, and validation datasets.

Args:
train_val_test_num_samples: A list containing the number of samples in train, test, and validation.

Returns:
train_ds, valid_ds, test_ds: The constructed datasets
"""
args = get_args()

config = core_gpt_dataset_config_from_args(args)

if args.sft:
from megatron.core.datasets.sft_dataset import SFTIndexedDataset
dataset_type = SFTIndexedDataset
elif args.mock_data:
dataset_type = MockGPTDataset
else:
dataset_type = GPTDataset

print_rank_0("> building train, validation, and test datasets for GPT ...")

train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
dataset_type,
train_val_test_num_samples,
is_dataset_built_on_rank,
config
).build()

print_rank_0("> finished creating GPT datasets ...")


def get_train_val_test_num_samples():
"""
Calculate the number of samples for train/val/test datasets.
These numbers do not affect the sample-packing calculation, but they are required to initialize the dataset builder.
"""
args = get_args()

# From training.py build_train_valid_test_datasets function
# We need to determine train_val_test_num_samples

# Number of train/valid/test samples.
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size

eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [
train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size
]

return train_val_test_num_samples


def main():
"""Main function to initialize dataset and exit."""

# Initialize Megatron (this handles argument parsing and distributed setup)
initialize_megatron(
extra_args_provider=None,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
)

args = get_args()

# Validate required arguments
if not args.sft or not args.sft_pack_samples:
print_rank_0("=" * 80)
print_rank_0("ERROR: This script requires both --sft and --sft-pack-samples flags")
print_rank_0("=" * 80)
print_rank_0("This script is specifically for determining packed sample counts")
print_rank_0("before running a full SFT training job with sample packing.")
print_rank_0("")
print_rank_0("Usage:")
print_rank_0(" python initialize_sft_dataset.py <args> --sft --sft-pack-samples")
print_rank_0("")
print_rank_0("For normal training without packing, use pretrain_gpt.py")
print_rank_0("=" * 80)
sys.exit(1)

print_rank_0("=" * 80)
print_rank_0("SFT Dataset Initialization Script")
print_rank_0("This script will build the dataset index and report packed sample counts")
print_rank_0("=" * 80)
print_rank_0("")
print_rank_0("IMPORTANT: SEED must match your intended training run! It determines packing and thus num of samples.")
print_rank_0(f" Seed: {args.seed}")
print_rank_0("=" * 80)
print_rank_0("")
print_rank_0("IMPORTANT: Parallelism settings (TP/PP/EP) should match your training run!")
print_rank_0(f" Tensor Parallel: {args.tensor_model_parallel_size}")
print_rank_0(f" Pipeline Parallel: {args.pipeline_model_parallel_size}")
if args.expert_model_parallel_size > 1:
print_rank_0(f" Expert Parallel: {args.expert_model_parallel_size}")
print_rank_0(f" World Size: {args.world_size}")
print_rank_0("")
print_rank_0("Note: Model architecture parameters (--num-layers, --hidden-size, etc.)")
print_rank_0(" are only needed to pass Megatron's validation. They don't affect")
print_rank_0(" the dataset packing calculation, which only depends on:")
print_rank_0(" - Data paths and tokenizer")
print_rank_0(" - Sequence length (--seq-length)")
print_rank_0(" - Global batch size (--global-batch-size)")
print_rank_0(" - Parallelism settings (TP/PP/EP)")
print_rank_0("=" * 80)
print_rank_0("")

# Calculate train/val/test sample counts
train_val_test_num_samples = get_train_val_test_num_samples()

# Build datasets (this will trigger the packing process)
build_train_valid_test_datasets(train_val_test_num_samples)


if __name__ == "__main__":
main()
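For context, a minimal sketch of the second phase this script enables: once the packed training-sample count has been reported, --train-iters for the real run can be derived from it and the global batch size. The helper name and the sample count below are hypothetical placeholders; the script does not emit them in this exact form.

# Illustrative only: turn the packed-sample count reported by
# initialize_sft_dataset.py into --train-iters for the actual training run.

def packed_samples_to_train_iters(num_packed_train_samples: int,
                                  global_batch_size: int,
                                  epochs: int = 1) -> int:
    """Number of optimizer steps needed to cover the packed dataset."""
    total_samples = num_packed_train_samples * epochs
    # Round down so we never request more samples than the packed dataset
    # can provide (packed samples cannot be regenerated on the fly).
    return total_samples // global_batch_size


if __name__ == "__main__":
    # e.g. the init script reported 118_304 packed training samples (hypothetical)
    print(packed_samples_to_train_iters(118_304, global_batch_size=8))  # -> 14788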
10 changes: 8 additions & 2 deletions megatron/core/datasets/blended_megatron_dataset_builder.py
@@ -214,8 +214,11 @@ def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
# The number of samples we plan to use per dataset
sizes_per_dataset_target = _get_size_per_split_per_dataset(weights, self.sizes)
# The number of samples we plan to build per dataset
# Skip the margin if requested (useful for fixed-size datasets such as packed SFT datasets);
# otherwise use the 0.5% margin to ensure enough samples are built
margin = 0.0 if self.config.skip_margin_samples else 0.5
sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
weights, self.sizes, margin=0.5
weights, self.sizes, margin=margin
)

# Build each dataset in parallel
@@ -297,8 +300,11 @@ def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
weights, sizes_spoof
)
# The number of samples we plan to build per dataset
# Skip the margin if requested (useful for fixed-size datasets such as packed SFT datasets);
# otherwise use the 0.5% margin to ensure enough samples are built
margin = 0.0 if self.config.skip_margin_samples else 0.5
sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
weights, sizes_spoof, margin=0.5
weights, sizes_spoof, margin=margin
)

# Build each dataset in parallel
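To illustrate why the new skip_margin_samples path matters: the buffer margin asks each constituent dataset for slightly more samples than its split target, which is harmless when extra samples can always be drawn but overshoots a packed SFT dataset whose sample count is fixed by the packing itself. The sketch below is a toy, not the real _get_size_per_split_per_dataset; it treats the margin as a plain fraction purely for illustration.

# Toy model of per-dataset size requests with and without a buffer margin.

def sizes_with_margin(weights, target_size, margin):
    """Per-dataset sample requests, inflated by a fractional margin."""
    return [int(round(w * target_size * (1.0 + margin))) for w in weights]


weights = [0.7, 0.3]
target = 100_000

print(sizes_with_margin(weights, target, margin=0.005))  # [70350, 30150] - buffered
print(sizes_with_margin(weights, target, margin=0.0))    # [70000, 30000] - skip_margin_samples=True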
5 changes: 5 additions & 0 deletions megatron/core/datasets/blended_megatron_dataset_config.py
@@ -63,6 +63,11 @@ class BlendedMegatronDatasetConfig:
tokenizer: Optional[MegatronTokenizer] = None
"""The MegatronTokenizer instance. Required for datasets that do online tokenization."""

skip_margin_samples: bool = False
"""Whether to skip the 0.5% margin when calculating dataset sizes. When True, requests
exactly the number of samples needed without buffer. Useful for datasets with fixed
size (e.g., SFT packed datasets) where samples cannot be regenerated."""

def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
if self.blend_per_split is not None and any(self.blend_per_split):
12 changes: 12 additions & 0 deletions megatron/core/datasets/gpt_dataset.py
@@ -61,6 +61,16 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):
s3_cache_path: str = None
"""Path for caching indices for s3 dataloading."""

"""
SFT Options
"""
sft_mask_special_tokens: bool = True # Mask EOD, BOD and assistant sequence begin tokens, NOT end of turn
sft_plw: float = 0.0 # prompt loss weight used
sft_pack_samples: bool = False # Enable packing multiple whole documents per sequence for SFT
sft_equalize_sample_loss: bool = False # loss between samples will be equal
sft_load_loss_mask: bool = False # Load pre-computed loss masks from disk alongside tokens
sft_disable_assistant_mask: bool = False # Disable assistant mask computation (set to None)

def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
@@ -256,13 +266,15 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
"assistant_mask": None, # Not used for standard GPT training
}
else:
return {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"position_ids": position_ids,
"assistant_mask": None, # Not used for standard GPT training
}

def _query_document_sample_shuffle_indices(
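A minimal sketch of the storage layout implied by sft_load_loss_mask and by the sequence-length doubling in initialize_sft_dataset.py: each on-disk row is assumed to hold the token IDs followed by an equally long pre-computed loss mask, so the dataset is built with 2 * seq_length and the row is split in half at load time. The helper below is illustrative, not the actual SFTIndexedDataset code (which is not part of this diff).

import numpy as np

def split_tokens_and_loss_mask(row: np.ndarray, seq_length: int):
    """Split one stored row of length 2 * seq_length into (tokens, loss_mask)."""
    assert row.shape[0] == 2 * seq_length, "row must hold tokens followed by a loss mask"
    tokens = row[:seq_length].astype(np.int64)
    # Assumed layout: the second half is a 0/1 (or weighted) loss mask.
    loss_mask = row[seq_length:].astype(np.float32)
    return tokens, loss_mask


seq_length = 4
row = np.array([11, 12, 13, 14,   # token IDs
                0,  1,  1,  1])   # pre-computed loss mask
tokens, loss_mask = split_tokens_and_loss_mask(row, seq_length)
print(tokens, loss_mask)  # [11 12 13 14] [0. 1. 1. 1.]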
89 changes: 89 additions & 0 deletions megatron/core/datasets/helpers.cpp
@@ -245,6 +245,93 @@ py::array_t<T> build_sample_idx(
);
}

template <typename DocIdx>
py::array_t<DocIdx> build_sample_idx_packed_whole_docs(
const py::array_t<int32_t>& sizes_,
const py::array_t<DocIdx>& doc_idx_,
int32_t seq_length,
int32_t add_extra_token_to_sequence) {
/* Build the sample index for SFT with whole-document packing.
* Pack whole documents into sequences until the next document doesn't fit.
* Never split documents across sequences - pad the remainder instead.
*
* Returns a 2D array [num_samples + 1, 2] where each row contains:
* [document_idx_index, 0] (offset is always 0 since we only use whole docs)
*/

// Ensure the input arrays are not empty
if (sizes_.size() == 0) {
throw std::domain_error("sizes_ is empty");
}
if (doc_idx_.size() == 0) {
throw std::domain_error("doc_idx_ is empty");
}

// Get buffer pointers
const int32_t* sizes = sizes_.data();
const DocIdx* doc_idx = doc_idx_.data(); // doc_idx keeps track of list of document indices

// Calculate adjusted sequence length
int32_t adjusted_seq_length = seq_length + add_extra_token_to_sequence;

// Count the number of samples we can create
std::vector<std::pair<DocIdx, DocIdx>> sample_starts;
sample_starts.reserve(doc_idx_.size()); // estimate

// Iterate through documents and pack them
DocIdx doc_idx_index = 0;
while (doc_idx_index < doc_idx_.size()) {
// Start a new sample
sample_starts.push_back({doc_idx_index, 0});
int32_t remaining_seq_length = adjusted_seq_length;

// Track how many documents have been added to this sequence. If even the first document is too long, it is still kept, so client code can truncate or discard it.
int32_t documents_in_sample = 0;

// Pack whole documents into this sequence
while (doc_idx_index < doc_idx_.size()) {
DocIdx doc_id = doc_idx[doc_idx_index];
int32_t doc_length = sizes[doc_id];

// Check if document fits in remaining space
if (doc_length <= remaining_seq_length) {
// Document fits - include it
remaining_seq_length -= doc_length;
doc_idx_index++;
documents_in_sample++;
} else {
// Document doesn't fit
if (documents_in_sample == 0) {
// If even the first document doesn't fit, take it anyway and move on to the next sample; truncation or discarding is left to client code.
// Otherwise, leave this document for the next sample (do not advance doc_idx_index).
doc_idx_index++;
}
break;
}

// If we've packed exactly to the sequence length, move to next sequence
if (remaining_seq_length == 0) {
break;
}
}
}

// Add the final boundary marker
sample_starts.push_back({doc_idx_index, 0});

// Convert to a numpy array with shape [num_samples + 1, 2] (one row per sample start, plus the final boundary)
size_t num_samples = sample_starts.size();
auto result = py::array_t<DocIdx>({num_samples, size_t(2)});
auto r = result.template mutable_unchecked<2>();

for (size_t i = 0; i < num_samples; ++i) {
r(i, 0) = sample_starts[i].first; // document_idx_index
r(i, 1) = sample_starts[i].second; // offset (always 0)
}

return result;
}

inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937 &rand32_gen)
@@ -841,6 +928,8 @@ PYBIND11_MODULE(helpers_cpp, m)
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx_int32", &build_sample_idx<int32_t>);
m.def("build_sample_idx_int64", &build_sample_idx<int64_t>);
m.def("build_sample_idx_packed_whole_docs_int32", &build_sample_idx_packed_whole_docs<int32_t>);
m.def("build_sample_idx_packed_whole_docs_int64", &build_sample_idx_packed_whole_docs<int64_t>);
m.def("build_blending_indices", &build_blending_indices);
m.def("build_exhaustive_blending_indices", &build_exhaustive_blending_indices);
}
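For readers who prefer Python, here is a rough reference sketch of the same greedy whole-document packing logic as build_sample_idx_packed_whole_docs above. It is an independent illustration, not the earlier (now removed) Python helper from this PR, and it returns the same [num_samples + 1, 2] boundary layout with offsets fixed at 0.

import numpy as np

def build_sample_idx_packed_whole_docs_py(sizes, doc_idx, seq_length,
                                          add_extra_token_to_sequence=1):
    """Greedy whole-document packing; mirrors the C++ helper, for illustration."""
    adjusted = seq_length + add_extra_token_to_sequence
    boundaries = []                          # rows of [doc_idx_index, offset]; offset is always 0
    i = 0
    while i < len(doc_idx):
        boundaries.append([i, 0])            # start a new packed sample here
        remaining = adjusted
        docs_in_sample = 0
        while i < len(doc_idx):
            doc_length = int(sizes[doc_idx[i]])
            if doc_length <= remaining:
                remaining -= doc_length      # document fits, keep packing
                i += 1
                docs_in_sample += 1
            else:
                if docs_in_sample == 0:
                    # Oversized first document: keep it and let the caller
                    # truncate or discard it, exactly like the C++ version.
                    i += 1
                break
            if remaining == 0:
                break
    boundaries.append([i, 0])                # final boundary marker
    return np.asarray(boundaries, dtype=np.int64)


sizes = np.array([5, 3, 9, 2, 4])            # document lengths
doc_idx = np.array([0, 1, 2, 3, 4])          # (shuffled) document order
print(build_sample_idx_packed_whole_docs_py(sizes, doc_idx, seq_length=8))
# -> rows [0, 2, 3, 5], i.e. samples {0, 1}, {2}, {3, 4}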