Allow loss masking for defined spans of characters #113
Conversation
Looks good so far, but can you please add a short description and/or point to an issue?
fast_llm/functional/cross_entropy.py (Outdated)

@@ -41,6 +51,10 @@ def fused_cross_entropy_forward_backward(
    """
    # Do the forward and backward passes all at once, and fused with dtype conversion.
    # Way faster and more memory-efficient than the pytorch version.
    if apply_loss_mask:
        loss_mask = target != ignore_index
Why are we first integrating the loss mask into the target tensor and then extracting it here again? Does this save any noteworthy amount of memory? Why not pass an explicit, optional binary loss mask to the cross-entropy function?
The idea was to let each implementation handle loss masking, and we only provide the indices to ignore (following torch's cross-entropy). Would you prefer creating the loss mask along with the mask indices here, and letting each implementation decide whether to use that or `ignore_index`? Not sure if there's a benefit.
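For context, a minimal sketch of the `ignore_index`-based approach being discussed, where the mask is recovered inside the implementation from the target tensor itself. The function name and shapes are hypothetical, not the actual fast_llm kernel:

```python
import torch

def cross_entropy_with_ignore_index(
    logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    # Recover the loss mask from the target itself, mirroring torch's ignore_index convention.
    loss_mask = target != ignore_index
    # Clamp masked positions to a valid class index so gather() stays in bounds.
    safe_target = target.masked_fill(~loss_mask, 0)
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    per_token_loss = -log_probs.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
    # Zero out masked positions and average over the unmasked ones only.
    return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```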
I don't want to delay this PR unnecessarily. If you and @jlamypoirier think that the combination of ignore indices in the target and `apply_loss_mask` is the right interface for the cross-entropy loss, then let's leave it like that.
[86, 89, 22255, 1073, 79, 480],
[86, 49152, 89, 22255, 1073, 79],
[8008, 498, 71, 727, 80, 315],
[2210, 8179, 73, 2582, 897, 1178],
[86, 89, 88, 87, 409, 70],
[86, 83, 744, 89, 64, 333],
[86, 89, 1461, 87, 330, 7876],
[86, 89, 88, 49152, 87, 49152],
[86, 49152, 83, 744, 89, 64],
[86, 89, 1461, 49152, 87, 49152],
@jlamypoirier can you pls check if this is good with you? I'm not sure if this is ok for FIM.

It changed since we're now specifying `begin` and `end` in the tokenizer. The test tokenizer doesn't add either by default, so we didn't see `49152` earlier. It is both `bod_id` and `eod_id` for that tokenizer.
LGTM! thanks!
Did a few small changes, should be good to merge if nobody opposes:
- Simplified the cross-entropy handling. I had a deeper look and determined that the best option is to standardize all methods and always enable masking for index < 0, which makes all the `apply_loss_mask` arguments unnecessary. The performance impact won't be a real problem: it was already on for triton, which is the one we care about the most, and for the compiled ones the overhead will probably be made negligible by compilation. The only remaining issue is that torch cross-entropy has the different behaviour of masking -100 only (see the sketch after this list).
- Simplified the tests.
- Replaced the random spans in tests with a much faster method.
- Fixed some other tests.
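As a rough illustration of the remaining torch discrepancy, here is a hedged sketch (hypothetical helper name, assuming flattened `(tokens, vocab)` logits) of how torch's -100-only behaviour could be reconciled with the "mask any index < 0" convention:

```python
import torch
import torch.nn.functional as F

def torch_cross_entropy_mask_negatives(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # torch's F.cross_entropy only ignores a single ignore_index (-100 by default),
    # so remap every negative target index to -100 before calling it.
    target = target.masked_fill(target < 0, -100)
    return F.cross_entropy(logits.float(), target, ignore_index=-100)
```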
@sohamparikh, can you run two experiments with SFT data, one with masking turned on and one with masking turned off? I'd like us to see what the performance impact of masking is (tokens/s/gpu), if any.
Ok, so 13% performance loss with masking turned on.
@tscholak pls disregard my plot for fused cross-entropy above, I've deleted it. Found a minor bug in fused (and the previous runs are unreliable). Will re-run and post again.
A smaller vocab because of #52? How small did you go? We are forced to use the fused implementation for our models right now, and the penalty there is 13% according to your earlier plot?
@tscholak Can you post more details about the 13%? I set the masking to be always on, so that looks really bad.
Pls allow me to clear up the confusion I created with the multiple (now deleted) plots. The numbers are generally quite similar both with and without loss masking, irrespective of the cross-entropy kernel used.
So, good news: no 13% penalty. All the runs are in this wandb project:
Ok, looks good then. The first run looks like a data loader bottleneck; the second has it in the local filesystem cache, so that's normal. We don't care much about masked vs. not masked for the model, since it's running the same code now. What we care about is the possible regression for fused without loss masking, and the lack of difference with triton seems to show it's not significant.
In this case, shouldn't we compare with `main`?
Ideally yes, but that wouldn't be that useful, because the triton implementation is unchanged, so we can deduce the performance in `main`.
Confirmed that the fused numbers with `main` are also similar. I'll merge this now.
✨ Description
Support loss masking for spans specified in the input data. This PR ensures that the loss is not computed on the specified spans. The biggest use case for this is instruction-tuning data, where we want to avoid training on the prompts.
Closes #109
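As a rough sketch of the masking idea (hypothetical helper; the PR specifies character spans in the input data, and the mapping from character spans to token positions is omitted here), spans can be turned into ignored target positions like this:

```python
import torch

def mask_spans(target: torch.Tensor, spans: list[tuple[int, int]], mask_index: int = -100) -> torch.Tensor:
    # Set every target position inside a masked span to a negative index so the
    # cross-entropy kernels exclude it from the loss (e.g. prompt tokens in SFT data).
    target = target.clone()
    for begin, end in spans:
        target[begin:end] = mask_index
    return target

# Example: mask the first four positions (the "prompt") of a 1D target sequence.
target = torch.tensor([86, 89, 22255, 1073, 79, 480])
print(mask_spans(target, [(0, 4)]))  # tensor([-100, -100, -100, -100,   79,  480])
```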
📝 Changes
List the key changes introduced in this PR:
🔍 Type of change