@RaphaelKreft

Implementation of an SFT dataset for SFT training in Megatron, originally developed for Visual Instruction Tuning.

  • Like gpt_dataset, it uses an IndexedDataset as the low-level dataset.
  • From there it loads pre-tokenized SFT data, then masks the user prompts (in our case including image tokens).
  • Per training sample: it loads a single sample from the indexed dataset, then pads it to the maximum sequence length (see the sketch below).
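
For orientation, a minimal sketch of that per-sample flow, assuming hypothetical helper names (prompt_mask_fn, pad_token_id) that stand in for the actual SFTIndexedDataset internals:

```python
import numpy as np
import torch

def build_sft_sample(indexed_dataset, idx, max_seq_len, pad_token_id, prompt_mask_fn):
    # Load one pre-tokenized sample from the low-level IndexedDataset.
    ids = torch.from_numpy(np.asarray(indexed_dataset[idx], dtype=np.int64))

    tokens = ids[:-1].contiguous()      # model inputs
    labels = ids[1:].contiguous()       # next-token targets
    loss_mask = torch.ones_like(labels, dtype=torch.float)
    # Mask the user prompt (including image tokens); prompt_mask_fn is assumed to
    # return a boolean mask over the labels.
    loss_mask[prompt_mask_fn(labels)] = 0.0

    # Pad tokens/labels/loss_mask up to the maximum sequence length.
    pad = max_seq_len - tokens.numel()
    if pad > 0:
        tokens = torch.cat([tokens, torch.full((pad,), pad_token_id, dtype=torch.long)])
        labels = torch.cat([labels, torch.full((pad,), pad_token_id, dtype=torch.long)])
        loss_mask = torch.cat([loss_mask, torch.zeros(pad)])

    return {"tokens": tokens, "labels": labels, "loss_mask": loss_mask}
```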

Added

  • --sft CLI argument (when given, the new SFTIndexedDataset is used)
  • SFTIndexedDataset, which loads and prepares pre-tokenized SFT data

Questions / Todos

  • Add dynamic loading of the "begin of user prompt" and "end-of-turn" sequences (could be part of the tokenizer; they are currently hard-coded and thus only work for the Llama 3 Vision model chat template)
  • Implement an option to mask the loss on all special tokens (BOS, EOD, EOS, SFT-related special tokens)?

return matches


def get_matching_mask_by_start_end(tokens, begin_seq: torch.Tensor, end_seq: torch.Tensor):


CRUCIAL BUG: We're Searching in the Wrong Place

Let’s make this concrete with a simple example.

Sequence:  a b c d e

Tokens:    a b c d
Labels:    b c d e

Each loss_mask[i] controls whether we train on predicting labels[i] from tokens[i], i.e.:

  • loss_mask[0]: a → b
  • loss_mask[1]: b → c
  • loss_mask[2]: c → d
  • loss_mask[3]: d → e

The Problem

Say we want to mask the prediction of 'c' (don’t train the model to predict 'c').

Current (incorrect) logic:

  • We search for 'c' in tokens → find index 2
  • Set loss_mask[2] = 0

Result:

a → b   (trained)
b → c   (trained)   ❌ we wanted to mask this one
c → d   (masked)
d → e   (trained)

This disables the c→d prediction instead of b→c.


Correct logic

We should search for 'c' in labels, not tokens:

  • 'c' is at label index 1
  • Set loss_mask[1] = 0

Result:

a → b   (trained)
b → c   (masked)    ✅ correct
c → d   (trained)
d → e   (trained)
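
The same example as a tiny standalone snippet (illustrative only, not the PR's code):

```python
import torch

vocab = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4}
seq = torch.tensor([vocab[t] for t in "abcde"])
tokens, labels = seq[:-1], seq[1:]

# Incorrect: search in tokens -> masks the c -> d step.
wrong_mask = torch.ones_like(labels, dtype=torch.float)
wrong_mask[tokens == vocab["c"]] = 0.0      # -> [1., 1., 0., 1.]

# Correct: search in labels -> masks the b -> c step.
right_mask = torch.ones_like(labels, dtype=torch.float)
right_mask[labels == vocab["c"]] = 0.0      # -> [1., 0., 1., 1.]
```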

@RaphaelKreft (Author):

I have now renamed the arguments of the sequence-matching methods (to "data" and "sequence") and, more importantly, pass them the labels to calculate the mask.

end_len = len(end_seq)

if 0 < begin_len <= len(tokens):
    matches_begin = get_matching_mask(tokens, begin_seq, only_begin=True)


The same applies here: get_matching_mask should work on the labels rather than the tokens.

Imagine the sequence: "<|start_header_id|>Assistant<|end_header_id|> Hi"

We want to enable the loss from <|end_header_id|> → Hi, as the model should learn to predict “Hi” from <|end_header_id|>.
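
A minimal sketch of that point, using made-up token IDs rather than the real Llama 3 special-token IDs:

```python
import torch

# Assumed IDs for illustration only.
START_HDR, ASSISTANT, END_HDR, HI = 10, 11, 12, 13

seq = torch.tensor([START_HDR, ASSISTANT, END_HDR, HI])
tokens, labels = seq[:-1], seq[1:]

# Mask computed on the labels: the header tokens are never training targets,
# but the <|end_header_id|> -> "Hi" step stays enabled.
loss_mask = torch.ones_like(labels, dtype=torch.float)
loss_mask[(labels == START_HDR) | (labels == ASSISTANT) | (labels == END_HDR)] = 0.0
# loss_mask == [0., 0., 1.]  -> only predicting "Hi" from <|end_header_id|> is trained
```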

"position_ids": position_ids,
}

def _get_ltor_masks_and_position_ids(self, tokens,


I would prefer using data (as in the original _get_ltor_masks_and_position_ids) instead of tokens here. In the original implementation, tokens are passed in to prevent generation after the EOD token during pretraining; for SFT, maybe we should pass in the labels instead.
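
A hedged sketch of why the choice matters (simplified; not Megatron's actual helper):

```python
import torch

def eod_loss_mask(data: torch.Tensor, eod_token: int) -> torch.Tensor:
    """Zero the loss wherever `data` holds the EOD token.

    If `data` is the tokens, this silences the generate-after-EOD step;
    if `data` is the labels, it silences the predict-EOD step instead.
    """
    mask = torch.ones_like(data, dtype=torch.float)
    mask[data == eod_token] = 0.0
    return mask
```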

if not only_begin:
    matches_float = matches.float().unsqueeze(0).unsqueeze(0)  # (1, 1, N)
    kernel = torch.ones(1, 1, query_len, device=sequence.device)
    expanded = F.conv1d(matches_float, kernel, padding=query_len - 1)


Do you think using conv1d here for the padding is a bit overkill? @TJ-Solergibert


Well, it can be, but as long as (1) it doesn't hurt performance and (2) everyone is comfortable with it, it's fine.

Keep in mind that this function runs (1) on the CPU and (2) while the GPU is processing the previous batch, so as long as you don't hit a CPU OOM error and are not bottlenecked by the DataLoader, you are good. To check for the latter, just compare the throughput against a run with mock data.
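
For reference, the conv1d trick expands a 0/1 vector of match starts into a mask covering every position of each match; a plain loop does the same thing. A sketch with assumed shapes (the PR's exact slicing may differ):

```python
import torch
import torch.nn.functional as F

def expand_matches_conv(match_starts: torch.Tensor, query_len: int) -> torch.Tensor:
    # match_starts: 0/1 vector of length N marking where a query match begins.
    m = match_starts.float().unsqueeze(0).unsqueeze(0)                       # (1, 1, N)
    kernel = torch.ones(1, 1, query_len, device=match_starts.device)
    out = F.conv1d(m, kernel, padding=query_len - 1).squeeze(0).squeeze(0)   # length N + query_len - 1
    return out[: match_starts.numel()] > 0                                   # True where a match covers the position

def expand_matches_loop(match_starts: torch.Tensor, query_len: int) -> torch.Tensor:
    covered = torch.zeros(match_starts.numel(), dtype=torch.bool)
    for s in torch.nonzero(match_starts).flatten().tolist():
        covered[s : s + query_len] = True
    return covered

# Both return tensor([False, True, True, False, False]) for:
#   expand_matches_conv(torch.tensor([0, 1, 0, 0, 0]), query_len=2)
```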

RaphaelKreft and others added 18 commits October 13, 2025 22:55
- Load sequences from tokenizer properties instead of tokenizing at runtime
- Pre-compute token sequences as tensors in __init__
- Use .to() instead of torch.tensor() in hot path for efficiency
- Reduces overhead during training data loading
…mples and then exit. Remove dummy packing arg and code from main code. (untested)
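
The first commit above (pre-computing the special-token sequences from tokenizer properties) suggests a pattern roughly like the following; begin_user_prompt_ids and end_of_turn_ids are hypothetical tokenizer attributes, not the actual names:

```python
import torch

class SFTSpecialSequences:
    """Sketch of 'pre-compute in __init__, only .to() in the hot path'."""

    def __init__(self, tokenizer):
        # Built once from tokenizer properties, instead of tokenizing at runtime.
        self.begin_user_seq = torch.tensor(tokenizer.begin_user_prompt_ids, dtype=torch.long)
        self.end_turn_seq = torch.tensor(tokenizer.end_of_turn_ids, dtype=torch.long)

    def for_labels(self, labels: torch.Tensor):
        # Hot path: only a device move, no torch.tensor() construction.
        return (self.begin_user_seq.to(labels.device),
                self.end_turn_seq.to(labels.device))
```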