Multimodality/sft extension #91
base: multimodality/images
Conversation
…e efficiency of loss-mask creation.
    return matches

def get_matching_mask_by_start_end(tokens, begin_seq: torch.Tensor, end_seq: torch.Tensor):
CRUCIAL BUG: We're Searching in the Wrong Place
Let’s make this concrete with a simple example.
Sequence: a b c d e
Tokens: a b c d
Labels: b c d e
Each loss_mask[i] controls whether we train on predicting labels[i] from tokens[i], i.e.:
loss_mask[0]: a → b
loss_mask[1]: b → c
loss_mask[2]: c → d
loss_mask[3]: d → e
The Problem
Say we want to mask the prediction of 'c' (don’t train the model to predict 'c').
Current (incorrect) logic:
- We search for 'c' in tokens → find index 2
- Set loss_mask[2] = 0
Result:
a → b (trained)
b → c (trained) ❌ we wanted to mask this one
c → d (masked)
d → e (trained)
This disables the c→d prediction instead of b→c.
Correct logic
We should search for 'c' in labels, not tokens:
- 'c' is at label index 1
- Set loss_mask[1] = 0
Result:
a → b (trained)
b → c (masked) ✅ correct
c → d (trained)
d → e (trained)
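To make the fix concrete, here is a minimal sketch (my own toy code, not the PR's implementation) of the two search strategies on the sequence above, with integer ids standing in for a–e:

```python
import torch

# Toy ids 0..4 stand in for a..e.
sequence = torch.tensor([0, 1, 2, 3, 4])        # a b c d e
tokens, labels = sequence[:-1], sequence[1:]    # shifted by one position

loss_mask = torch.ones_like(tokens, dtype=torch.float)
target = 2  # 'c' — we do not want to train the model to predict 'c'

# Incorrect: search in tokens → index 2 → would mask the c → d prediction
wrong_idx = (tokens == target).nonzero(as_tuple=True)[0]   # tensor([2])

# Correct: search in labels → index 1 → masks the b → c prediction
right_idx = (labels == target).nonzero(as_tuple=True)[0]   # tensor([1])
loss_mask[right_idx] = 0.0
print(loss_mask)  # tensor([1., 0., 1., 1.])
```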
I have now renamed the arguments of the sequence matching methods (to "data" and "sequence") and, more importantly, they now receive the labels to calculate the mask.
end_len = len(end_seq)

if 0 < begin_len <= len(tokens):
    matches_begin = get_matching_mask(tokens, begin_seq, only_begin=True)
Same applies here, get_matching_mask should work on the labels rather than tokens.
Imagine the sequence: "<|start_header_id|>Assistant<|end_header_id|> Hi"
We want to enable the loss from <|end_header_id|> → Hi, as the model should learn to predict “Hi” from <|end_header_id|>.
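As a hedged illustration (hypothetical token ids, not the actual tokenizer), matching against the labels keeps exactly that transition trainable:

```python
import torch

# Hypothetical ids: 10 = <|start_header_id|>, 11 = Assistant,
# 12 = <|end_header_id|>, 13 = Hi
sequence = torch.tensor([10, 11, 12, 13])
tokens, labels = sequence[:-1], sequence[1:]    # tokens: 10 11 12, labels: 11 12 13

template_ids = {10, 11, 12}
# Mask template tokens in the labels: only <|end_header_id|> → Hi keeps a loss.
loss_mask = torch.tensor([0.0 if t.item() in template_ids else 1.0 for t in labels])
print(loss_mask)  # tensor([0., 0., 1.])
```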
| "position_ids": position_ids, | ||
| } | ||
|
|
||
| def _get_ltor_masks_and_position_ids(self, tokens, |
I would prefer using data, as in the original _get_ltor_masks_and_position_ids, rather than tokens here. In the original implementation, tokens are passed to prevent generation after the EOD token during pretraining; for SFT maybe we should pass in the labels instead.
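For context, a simplified sketch (written from memory, not Megatron's verbatim code) of why the tensor passed in matters: the loss mask is zeroed wherever that tensor equals the EOD token, so passing tokens or labels decides which prediction gets masked.

```python
import torch

def eod_loss_mask(data: torch.Tensor, eod_token: int) -> torch.Tensor:
    # Simplified stand-in for the eod_mask_loss branch of
    # _get_ltor_masks_and_position_ids: drop positions equal to EOD.
    loss_mask = torch.ones_like(data, dtype=torch.float)
    loss_mask[data == eod_token] = 0.0
    return loss_mask

EOD = 0  # hypothetical EOD id
tokens = torch.tensor([5, 6, EOD, 7])
labels = torch.tensor([6, EOD, 7, 8])

print(eod_loss_mask(tokens, EOD))  # tensor([1., 1., 0., 1.]) → masks the prediction made from EOD
print(eod_loss_mask(labels, EOD))  # tensor([1., 0., 1., 1.]) → masks the prediction of EOD itself
```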
if not only_begin:
    matches_float = matches.float().unsqueeze(0).unsqueeze(0)  # (1, 1, N)
    kernel = torch.ones(1, 1, query_len, device=sequence.device)
    expanded = F.conv1d(matches_float, kernel, padding=query_len - 1)
Do you think using conv1d here for padding is a bit overkill? @TJ-Solergibert
Well, it can be, but as long as it 1. doesn't hurt performance and 2. everyone is comfortable with it, it's fine.
Keep in mind that this function runs 1. on the CPU and 2. while the GPU is processing the previous batch, so as long as you don't hit a CPU OOM error and you are not bottlenecked by the DataLoader, you are good. To check for the latter, just compare the throughput when using mock data.
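For intuition, here is a toy comparison (my own sketch; in particular, slicing the padded conv output back to the original length is an assumption about the surrounding code): the conv1d with an all-ones kernel dilates each match marker forward over query_len positions, and a plain loop gives the same result.

```python
import torch
import torch.nn.functional as F

matches = torch.tensor([0., 0., 1., 0., 0., 0.])  # match marker at index 2
query_len = 3

# conv1d version, as in the snippet above
expanded = F.conv1d(matches.view(1, 1, -1),
                    torch.ones(1, 1, query_len),
                    padding=query_len - 1)
conv_mask = expanded.view(-1)[: matches.numel()] > 0

# equivalent loop version
loop_mask = torch.zeros(matches.numel(), dtype=torch.bool)
for start in matches.nonzero(as_tuple=True)[0].tolist():
    loop_mask[start : start + query_len] = True

print(conv_mask)  # tensor([False, False,  True,  True,  True, False])
print(loop_mask)  # tensor([False, False,  True,  True,  True, False])
```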
- Load sequences from tokenizer properties instead of tokenizing at runtime
- Pre-compute token sequences as tensors in __init__
- Use .to() instead of torch.tensor() in hot path for efficiency
- Reduces overhead during training data loading
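A hedged sketch of that pattern (the class and attribute names are illustrative, not the PR's actual code):

```python
import torch

class MarkerCache:
    """Illustrative only: build marker token sequences once, reuse them per sample."""

    def __init__(self, begin_ids: list[int], end_ids: list[int]):
        # Pre-computed once in __init__ instead of per __getitem__ call.
        self.begin_seq = torch.tensor(begin_ids, dtype=torch.long)
        self.end_seq = torch.tensor(end_ids, dtype=torch.long)

    def on_device_of(self, tokens: torch.Tensor):
        # Hot path: .to() is a no-op when the device already matches,
        # whereas torch.tensor(...) would allocate on every call.
        return self.begin_seq.to(tokens.device), self.end_seq.to(tokens.device)

cache = MarkerCache(begin_ids=[101, 102], end_ids=[103])
tokens = torch.randint(0, 200, (16,))
begin_seq, end_seq = cache.on_device_of(tokens)
```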
…for now (untested)
…mples and then exit. Remove dummy packing arg and code from main code. (untested)
…egardless of checkpoint interval
…ort to sft-dataset.
…to samples needed.
… not form eod indices
Implementation of an SFT Dataset for SFT Training in Megatron. Developed originally for Visual Instruction Tuning.
Added
- --sft CLI argument (when given, use the new SFTIndexedDataset)
- SFTIndexedDataset loading and preparing pre-tokenized SFT data
Questions / Todos
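Regarding the --sft argument listed under Added, a hypothetical sketch of how such a flag could select the dataset path (the flag name follows my reading of the description; the wiring and return values are purely illustrative):

```python
import argparse

def build_dataset(args) -> str:
    # When --sft is given, use the new SFTIndexedDataset path
    # instead of the default pretraining dataset.
    return "SFTIndexedDataset" if args.sft else "GPTDataset"

parser = argparse.ArgumentParser()
parser.add_argument("--sft", action="store_true",
                    help="use the new SFTIndexedDataset")
print(build_dataset(parser.parse_args(["--sft"])))  # SFTIndexedDataset
```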