Commits (41)
e8bda1c
add sft argument and preliminary implementation of SFTIndexedDataset
RaphaelKreft Oct 2, 2025
c9b5701
make SFTIndexedDataset align with GPTDataset.
RaphaelKreft Oct 2, 2025
5bfcacb
Implement masking of user sequences
RaphaelKreft Oct 4, 2025
5a5b5de
Debug SFTIndexed Dataset
Oct 5, 2025
6e458a8
add masking for attention rows for padding tokens
RaphaelKreft Oct 5, 2025
4ff1170
Remove special tokens from user start/end sequence
Oct 5, 2025
df80bc2
end truncated samples with eod token
RaphaelKreft Oct 6, 2025
912d683
Add option to mask special token(sequences) in sft dataset and improv…
RaphaelKreft Oct 7, 2025
c591b9a
debugged sdt and tokenizer. added options to not mask image tokens
Oct 11, 2025
812ea3e
fix: use labels to calculate values for loss mask NOT tokens
RaphaelKreft Oct 13, 2025
bc021e1
implement right padding, add debug flag, remove goldfish loss, cleanup
RaphaelKreft Oct 14, 2025
49a4e8a
remove caching as it potentially can cause issues. Add cmd arg for plw
RaphaelKreft Oct 16, 2025
96ea32d
Add tokenizer properties for pre-tokenized SFT sequences
Alvorecer721 Oct 16, 2025
9f14f5b
Use pre-tokenized sequences from tokenizer config
Alvorecer721 Oct 16, 2025
d832f6e
add support for plw in sft_dataset.py.
RaphaelKreft Oct 16, 2025
74baf69
fix: add sft_plw to gpt dataset config properly
Oct 16, 2025
7b81b95
add assistant loss logging for sft (untested)
RaphaelKreft Oct 17, 2025
515977b
improve assistant loss logging to work with plw=1 (untested)
RaphaelKreft Oct 20, 2025
1ee7c2e
fix assistant mask not passed though callchain correctly
Oct 20, 2025
56fbc8c
implement sample-packing for sft dataset (untested)
RaphaelKreft Oct 21, 2025
e14e9cf
add and use python-version of build_packed_idx method in sft_dataset …
RaphaelKreft Oct 21, 2025
0fcbc7a
Fix minor issues to make packing training launch and compile
Oct 22, 2025
846ab71
add sft init script that can init sft dataset once to print packed sa…
RaphaelKreft Oct 22, 2025
3693f1e
log packing statistics also if loaded from cache
RaphaelKreft Oct 22, 2025
572b92d
remove unecessary prints from init sft script. Sft dset prints sttist…
Oct 22, 2025
07b185c
Add --final-checkpoint arg, that enables storing a final checkpoint r…
RaphaelKreft Oct 23, 2025
4f1ee44
Add option to skip skip margin samples, to exactly reach target numbe…
RaphaelKreft Oct 24, 2025
b89133c
Add option to skip skip margin samples, to exactly reach target numbe…
RaphaelKreft Oct 24, 2025
217f3ea
Minor fixes to sft-dataset
Oct 24, 2025
c5eef18
re-enable forward pre-hook before final checkpoint
RaphaelKreft Oct 26, 2025
ba0e5c6
activate c++ index building helper, deprecate python one
RaphaelKreft Oct 28, 2025
53b4619
remove not-mask img tokens, add sample averaging and multi-epoch supp…
RaphaelKreft Nov 6, 2025
4192389
use correct shuffle in single doc idx building
Nov 6, 2025
129d45a
extend key conf attributes of sft_dataset to account for packing
RaphaelKreft Nov 6, 2025
8383953
cleanup old python index building implementation. Limit sample index …
RaphaelKreft Nov 7, 2025
1ea97b4
store correct size of sample index and mitigate NaN when equalizing s…
RaphaelKreft Nov 7, 2025
bf5d2b7
move zero-out of loss-mask into respective method
RaphaelKreft Nov 7, 2025
da8cb3b
fix attempt NaN loss
RaphaelKreft Nov 7, 2025
ab43c71
sft-dataset: obtain sample boundaries during low level sample loading…
RaphaelKreft Nov 17, 2025
bf74020
add load loss mask from disk option - cleanup sft dataset
RaphaelKreft Nov 19, 2025
8c5188f
minor fixes to sft_dataset.py after refactor
Nov 19, 2025
256 changes: 193 additions & 63 deletions megatron/core/datasets/sft_dataset.py
@@ -4,6 +4,8 @@
import os
import logging
import numpy as np
import json
from pathlib import Path
import torch
import torch.nn.functional as F

@@ -33,14 +35,16 @@ def __init__(
# Call Megatron Dataset init instead of direct parent, as we initialize index differently
MegatronDataset.__init__(self, dataset, dataset_path, indexed_indices, num_samples, index_split, config)

self.debug_writer = DebugDataWriter(output_dir="/users/rkreft/debug_data") # Initialize once, outside the loop

self.tokenizer = config.tokenizer
# Set pad token
try:
self._pad_token_id = self.tokenizer.pad
except Exception:
self._pad_token_id = _PAD_TOKEN_ID

# End of Document token to add end to truncated samples
# End of Document token to add end to truncated samples TODO: currently works with HF tokenizers only
self._eod_token_id = self.tokenizer.eod
self._bos_token_id = self.tokenizer.bos
# TODO: Pass sequences dynamically
@@ -49,13 +53,18 @@
self._sft_turn_end_sequence = self.tokenizer.tokenize('<|eot_id|>', add_special_tokens=False)
self._sft_assistant_begin_sequence = self.tokenizer.tokenize('<|start_header_id|>assistant<|end_header_id|>',
add_special_tokens=False)
# TODO: Make multimodality optional - configurable or just leave this as long as there is only one option?
self._img_begin_sequence = self.tokenizer.tokenize('<|img_start|>', add_special_tokens=False)
self._img_end_sequence = self.tokenizer.tokenize('<|img_end|>', add_special_tokens=False)

# Configure token(sequences to remove from loss calculations)
self.tokens_to_mask = [] # a list of: token ids or sequences of token ids to mask
if self.config.sft_mask_special_tokens:
# add tokenizer special tokens like EOS, BOS to be masked
self.tokens_to_mask += list(self.tokenizer.special_tokens)
self.tokens_to_mask.append(self._sft_user_assistant_sequence)
logger.warning(f"Masking the following tokens/token-sequences: {self.tokens_to_mask}")
self.tokens_to_mask.append([self._eod_token_id])
self.tokens_to_mask.append([self._bos_token_id])
self.tokens_to_mask.append(self._sft_assistant_begin_sequence)
log_single_rank(logger, logging.WARNING, f"Masking the following tokens/token-sequences: {self.tokens_to_mask}", )

# Build shuffle indices
self.document_index = self._build_single_document_indices()
@@ -212,7 +221,8 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
text = np.concatenate([trunc_doc, np.array([self._eod_token_id])])
else:
padding_length = target_length - len(document)
text = np.concatenate([document, np.full(padding_length, self._pad_token_id, dtype=np.int64)])
# Pad on left side with pad token
text = np.concatenate([np.full(padding_length, self._pad_token_id, dtype=np.int64), document])

text = torch.from_numpy(text).long()

@@ -250,17 +260,39 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
loss_mask[labels == self._pad_token_id] = 0.0

# DEBUG: Log how many tokens are actually being trained on
#num_unmasked = loss_mask.sum().item()
#total_tokens = loss_mask.numel()
#if idx is not None and idx % 100 == 0: # Log every 100 samples
# logger.warning(f"Sample {idx}: {num_unmasked}/{total_tokens} tokens unmasked "
# f"({100*num_unmasked/total_tokens:.1f}%), "
# f"doc_length={len(torch.from_numpy(document))}, "
# f"num_pad_tokens={(labels == self._pad_token_id).sum().item()}")
# user_begin_seq = torch.tensor(self._sft_user_begin_sequence, dtype=tokens.dtype, device=tokens.device)
# logger.warning(f"Sample {idx}: Looking for user_begin pattern: {user_begin_seq.tolist()}")
# logger.warning(f"Sample {idx}: Token sequence sample: {tokens[:100].tolist()}")
# logger.warning(f"Sample {idx}: Loss mask sample: {loss_mask[:100].tolist()}")
num_unmasked = loss_mask.sum().item()
total_tokens = loss_mask.numel()

if False and idx is not None and idx % 100 == 0: # Log every 100 samples
logger.warning(f"Sample {idx} - DOC {actual_doc_id}: {num_unmasked}/{total_tokens} tokens unmasked "
f"({100*num_unmasked/total_tokens:.1f}%), "
f"doc_length={len(torch.from_numpy(document))}, "
f"num_pad_tokens={(labels == self._pad_token_id).sum().item()}")

# Store to files
self.debug_writer.append_sample(
idx=idx,
actual_doc_id=actual_doc_id,
tokens=tokens,
loss_mask=loss_mask,
attention_mask=attention_mask,
position_ids=position_ids,
labels=labels,
document_length=len(torch.from_numpy(document)),
pad_token_id=self._pad_token_id
)

# Continue with existing logging
user_begin_seq = torch.tensor(self._sft_user_begin_sequence, dtype=tokens.dtype, device=tokens.device)
logger.warning(f"Sample {idx}: Looking for user_begin pattern: {user_begin_seq.tolist()}")
logger.warning(f"Sample {idx}: Token sequence sample: {tokens[:100].tolist()}")
logger.warning(f"Sample {idx}: Loss mask sample: {loss_mask[:100].tolist()}")
logger.warning(f"Sample {idx}: Loss mask sample last 100: {loss_mask[-100:].tolist()}")
logger.warning(f"Sample {idx}: Position ids sample: {position_ids[:100].tolist()}")
if attention_mask is not None:
logger.warning(f"Sample {idx}: Attention mask sample (1=masked): {attention_mask[0,0,:100].long().tolist()}")
# END DEBUG


# Map pad tokens to valid embedding indices
tokens[tokens == self._pad_token_id] = 0
@@ -315,50 +347,8 @@ def _get_ltor_masks_and_position_ids(self, tokens,
begin_seq = torch.tensor(self._sft_user_begin_sequence, dtype=tokens.dtype, device=tokens.device)
end_seq = torch.tensor(self._sft_turn_end_sequence, dtype=tokens.dtype, device=tokens.device)

begin_len = len(begin_seq)
end_len = len(end_seq)

if 0 < begin_len <= len(tokens):
matches_begin = get_matching_mask(tokens, begin_seq, only_begin=True)

if end_len > 0:
matches_end = get_matching_mask(tokens, end_seq, only_begin=True)

begin_indices = torch.where(matches_begin)[0]
end_indices = torch.where(matches_end)[0]

# Vectorized masking
if len(begin_indices) > 0 and len(end_indices) > 0:
# For each begin, find the next ends (vectorized)
end_matrix = end_indices.unsqueeze(0) > begin_indices.unsqueeze(1)
has_valid_end = end_matrix.any(dim=1)
first_end_idx = end_matrix.int().argmax(dim=1)

# Compute end positions for each begin
end_positions = torch.where(
has_valid_end,
end_indices[first_end_idx] + end_len,
len(loss_mask)
)

# Create ranges and mask in one go, Shape: (num_begins, max_range_len)
max_len = (end_positions - begin_indices).max().item()
ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
lengths = (end_positions - begin_indices).unsqueeze(1)

# Get all indices to mask
mask_positions = begin_indices.unsqueeze(1) + ranges
valid_mask = ranges < lengths
indices_to_mask = mask_positions[valid_mask]

loss_mask[indices_to_mask] = 0.0
elif len(begin_indices) > 0:
# No end sequences, mask from each begin to the end
max_len = len(loss_mask) - begin_indices.min().item()
ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
mask_positions = begin_indices.unsqueeze(1) + ranges
valid = mask_positions < len(loss_mask)
loss_mask[mask_positions[valid]] = 0.0
user_seq_mask = get_matching_mask_by_start_end(tokens, begin_seq, end_seq)
loss_mask[user_seq_mask] = 0.0

# 2) Mask other token(sequences) as configured in init (might contain BOS, EOS, assistant begin)
for t in self.tokens_to_mask:
@@ -371,19 +361,27 @@
raise ValueError(f"Invalid token to mask: {t}")
loss_mask[mask] = 0.0

# 3) Unmask Image tokens if configured
if self.config.sft_do_not_mask_image_tokens:
img_begin_tensor = torch.tensor(self._img_begin_sequence, dtype=tokens.dtype, device=tokens.device)
img_end_tensor = torch.tensor(self._img_end_sequence, dtype=tokens.dtype, device=tokens.device)
img_seq_mask = get_matching_mask_by_start_end(tokens, img_begin_tensor, img_end_tensor)
loss_mask[img_seq_mask] = 1.0

# 4) Create attention mask, excluding padding tokens from attention
if create_attention_mask:
# Here me mask attention from all padding tokens to all other tokens and vice versa
attention_mask = torch.tril(
torch.ones((self.config.sequence_length, self.config.sequence_length), device=tokens.device)
)
# Mask padding tokens in attention mask:
#no_padding_mask = (tokens != self._pad_token_id).float() # 1=real, 0=padding
no_padding_mask = (tokens != self._pad_token_id).float() # 1=real, 0=padding

# Mask both rows (queries from padding) and columns (keys to padding)
# Row masking: padding tokens shouldn't attend to anything
#attention_mask = attention_mask * no_padding_mask.unsqueeze(1)
attention_mask = attention_mask * no_padding_mask.unsqueeze(1)
# Column masking: nothing should attend to padding tokens
#attention_mask = attention_mask * no_padding_mask.unsqueeze(0)
attention_mask = attention_mask * no_padding_mask.unsqueeze(0)

# Convert attention mask to binary:
attention_mask = attention_mask.unsqueeze(0)
@@ -419,3 +417,135 @@ def get_matching_mask(sequence, query: torch.Tensor, only_begin:bool=True):
return matches


def get_matching_mask_by_start_end(tokens, begin_seq: torch.Tensor, end_seq: torch.Tensor):

Reviewer comment:

CRUCIAL BUG: We're Searching in the Wrong Place

Let’s make this concrete with a simple example.

Sequence:  a b c d e

Tokens:    a b c d
Labels:    b c d e

Each loss_mask[i] controls whether we train on predicting labels[i] from tokens[i], i.e.:

  • loss_mask[0]: a → b
  • loss_mask[1]: b → c
  • loss_mask[2]: c → d
  • loss_mask[3]: d → e

The Problem

Say we want to mask the prediction of 'c' (don’t train the model to predict 'c').

Current (incorrect) logic:

  • We search for 'c' in tokens → find index 2
  • Set loss_mask[2] = 0

Result:

a → b   (trained)
b → c   (trained)   ❌ we wanted to mask this one
c → d   (masked)
d → e   (trained)

This disables the c→d prediction instead of b→c.


Correct logic

We should search for 'c' in labels, not tokens:

  • 'c' is at label index 1
  • Set loss_mask[1] = 0

Result:

a → b   (trained)
b → c   (masked)    ✅ correct
c → d   (trained)
d → e   (trained)
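
A minimal runnable sketch of the same point, using made-up token ids rather than the PR's code: loss_mask[i] gates the prediction tokens[i] → labels[i], so the position to zero must be located in labels, not in tokens.

import torch

# Sequence a b c d e, shifted into model inputs and targets.
sequence = torch.tensor([10, 11, 12, 13, 14])   # a b c d e (illustrative ids)
tokens, labels = sequence[:-1], sequence[1:]    # tokens: a b c d, labels: b c d e
target = 12                                     # 'c': do not train on predicting it

# Incorrect: locate 'c' in tokens -> index 2 -> disables the c -> d step.
wrong_mask = torch.ones(len(tokens))
wrong_mask[tokens == target] = 0.0              # [1, 1, 0, 1]

# Correct: locate 'c' in labels -> index 1 -> disables the b -> c step.
right_mask = torch.ones(len(labels))
right_mask[labels == target] = 0.0              # [1, 0, 1, 1]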

Author reply:

I have now renamed the arguments of the sequence-matching methods (to "data" and "sequence") and, more importantly, pass them the labels to calculate the mask.

"""
Given a sequence and a start and end query, return a mask indicating which positions in the sequence
are between the start and end queries (inclusive).
"""
mask = torch.zeros(len(tokens), dtype=torch.bool, device=tokens.device)
begin_len = len(begin_seq)
end_len = len(end_seq)

if 0 < begin_len <= len(tokens):
matches_begin = get_matching_mask(tokens, begin_seq, only_begin=True)

Reviewer comment:

The same applies here: get_matching_mask should work on the labels rather than the tokens.

Imagine the sequence: "<|start_header_id|>Assistant<|end_header_id|> Hi"

We want to enable the loss from <|end_header_id|> → Hi, as the model should learn to predict “Hi” from <|end_header_id|>.
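
A similar hedged sketch for this case, again with made-up token ids rather than the tokenizer's real ids: matching the header span against tokens also zeroes the position whose label is "Hi", whereas matching against labels leaves the <|end_header_id|> → Hi step trainable.

import torch

# Made-up ids for <|start_header_id|>, "assistant", <|end_header_id|>, "Hi".
HDR_START, ASSISTANT, HDR_END, HI = 1, 2, 3, 42
seq = torch.tensor([HDR_START, ASSISTANT, HDR_END, HI])
tokens, labels = seq[:-1], seq[1:]              # labels: [assistant, end_header, Hi]
header_ids = torch.tensor([HDR_START, ASSISTANT, HDR_END])

# Matching against tokens masks every position, including end_header -> Hi.
mask_on_tokens = (~torch.isin(tokens, header_ids)).float()   # [0, 0, 0]
# Matching against labels keeps the end_header -> Hi prediction trainable.
mask_on_labels = (~torch.isin(labels, header_ids)).float()   # [0, 0, 1]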


if end_len > 0:
matches_end = get_matching_mask(tokens, end_seq, only_begin=True)

begin_indices = torch.where(matches_begin)[0]
end_indices = torch.where(matches_end)[0]

# Vectorized masking
if len(begin_indices) > 0 and len(end_indices) > 0:
# For each begin, find the next ends (vectorized)
end_matrix = end_indices.unsqueeze(0) > begin_indices.unsqueeze(1)
has_valid_end = end_matrix.any(dim=1)
first_end_idx = end_matrix.int().argmax(dim=1)

# Compute end positions for each begin
end_positions = torch.where(
has_valid_end,
end_indices[first_end_idx] + end_len,
len(mask)
)

# Create ranges and mask in one go, Shape: (num_begins, max_range_len)
max_len = (end_positions - begin_indices).max().item()
ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
lengths = (end_positions - begin_indices).unsqueeze(1)

# Get all indices to mask
mask_positions = begin_indices.unsqueeze(1) + ranges
valid_mask = ranges < lengths
indices_to_mask = mask_positions[valid_mask]

mask[indices_to_mask] = True
elif len(begin_indices) > 0:
# No end sequences, mask from each begin to the end
max_len = len(mask) - begin_indices.min().item()
ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
mask_positions = begin_indices.unsqueeze(1) + ranges
valid = mask_positions < len(mask)
mask[mask_positions[valid]] = True
return mask


class DebugDataWriter:
"""Helper class to write debug data to files."""

def __init__(self, output_dir="debug_data"):
self.output_dir = Path(output_dir)
self.output_dir.mkdir(exist_ok=True)
self.metadata_file = self.output_dir / "metadata.jsonl"
self.tokens_file = self.output_dir / "tokens.npy"
self.loss_mask_file = self.output_dir / "loss_mask.npy"
self.attention_mask_file = self.output_dir / "attention_mask.npy"

# Initialize files if they don't exist
if not self.tokens_file.exists():
self._init_array_file(self.tokens_file)
if not self.loss_mask_file.exists():
self._init_array_file(self.loss_mask_file)
if not self.attention_mask_file.exists():
self._init_array_file(self.attention_mask_file)

def _init_array_file(self, filepath):
"""Initialize an empty numpy array file."""
np.save(filepath, np.array([]))

def append_sample(self, idx, actual_doc_id, tokens, loss_mask,
attention_mask=None, position_ids=None,
labels=None, document_length=None, pad_token_id=None):
"""Append a sample to the debug files."""

# Convert tensors to numpy
tokens_np = tokens.cpu().numpy() if torch.is_tensor(tokens) else tokens
loss_mask_np = loss_mask.cpu().numpy() if torch.is_tensor(loss_mask) else loss_mask

# Load existing data
tokens_data = np.load(self.tokens_file, allow_pickle=True)
loss_mask_data = np.load(self.loss_mask_file, allow_pickle=True)

# Append new data
if tokens_data.size == 0:
tokens_data = np.array([tokens_np], dtype=object)
loss_mask_data = np.array([loss_mask_np], dtype=object)
else:
tokens_data = np.append(tokens_data, [tokens_np])
loss_mask_data = np.append(loss_mask_data, [loss_mask_np])

# Save updated arrays
np.save(self.tokens_file, tokens_data)
np.save(self.loss_mask_file, loss_mask_data)

# Handle attention mask
if attention_mask is not None:
attention_mask_np = attention_mask.cpu().numpy() if torch.is_tensor(attention_mask) else attention_mask
attention_mask_data = np.load(self.attention_mask_file, allow_pickle=True)

if attention_mask_data.size == 0:
attention_mask_data = np.array([attention_mask_np], dtype=object)
else:
attention_mask_data = np.append(attention_mask_data, [attention_mask_np])

np.save(self.attention_mask_file, attention_mask_data)

# Save metadata
num_unmasked = loss_mask_np.sum()
total_tokens = loss_mask_np.size

metadata = {
"sample_idx": int(idx) if idx is not None else None,
"doc_id": int(actual_doc_id) if actual_doc_id is not None else None,
"num_unmasked": int(num_unmasked),
"total_tokens": int(total_tokens),
"unmasked_percentage": float(100 * num_unmasked / total_tokens),
"document_length": int(document_length) if document_length is not None else None,
"num_pad_tokens": int((labels == pad_token_id).sum()) if labels is not None and pad_token_id is not None else None,
"sequence_length": int(len(tokens_np)),
"has_attention_mask": attention_mask is not None
}

# Append metadata as JSONL
with open(self.metadata_file, 'a') as f:
f.write(json.dumps(metadata) + '\n')
4 changes: 4 additions & 0 deletions megatron/training/tokenizer/tokenizer.py
@@ -180,6 +180,10 @@ def offsets(self, ids: list[int], text: str) -> list[int]:
@property
def eod(self):
return self._tokenizer.eos_token_id

@property
def bos(self):
return self._tokenizer.bos_token_id


class _BertWordPieceTokenizer(MegatronTokenizer):
Expand Down