Allow loss masking for defined spans of characters #113


Merged
merged 51 commits from soham/loss-masking-spans into main on Feb 7, 2025

Conversation


@sohamparikh sohamparikh commented Jan 14, 2025

✨ Description

Support loss masking for spans specified in the input data. This PR ensures that the loss is not computed on the specified spans. The main use case is instruction-tuning data, where we want to avoid training on the prompts.

Closes #109

📝 Changes

Key changes introduced in this PR:

  • Support character spans as inputs specified in the prepare command
  • Read the spans during training and apply masks to the cross-entropy loss (a minimal sketch follows below)
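
As an illustration of how span masking can work (a hedged sketch, not the actual Fast-LLM implementation; `mask_targets`, `IGNORE_INDEX`, and the offset format are assumptions made for this example): character spans are mapped to the tokens they overlap, and those target positions are replaced with an ignore value so they contribute nothing to the loss.

```python
# Hedged sketch: map character spans to token positions using per-token character
# offsets, then replace those targets with IGNORE_INDEX so they are excluded from
# the loss. Helper names and the offset format are assumptions, not Fast-LLM code.
from typing import List, Tuple

IGNORE_INDEX = -100  # sentinel for "no loss here" (torch's cross_entropy default)


def mask_targets(
    token_ids: List[int],
    offsets: List[Tuple[int, int]],  # (char_start, char_end) for each token
    masked_char_spans: List[Tuple[int, int]],  # character spans to exclude
) -> List[int]:
    """Return a copy of token_ids with masked positions set to IGNORE_INDEX."""
    targets = list(token_ids)
    for span_start, span_end in masked_char_spans:
        for i, (tok_start, tok_end) in enumerate(offsets):
            # Mask any token whose character range overlaps the span.
            if tok_start < span_end and tok_end > span_start:
                targets[i] = IGNORE_INDEX
    return targets


# Toy example: mask the prompt (characters 0-6) of "Prompt answer."
print(mask_targets([101, 102, 103], [(0, 6), (7, 13), (13, 14)], [(0, 6)]))
# -> [-100, 102, 103]
```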

🔍 Type of change

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@jlamypoirier
Collaborator

Looks good so far, but can you please add a short description and/or point to an issue?

@sohamparikh sohamparikh changed the title from "convert character spans to token spans" to "Allow loss masking for defined spans of characters" on Jan 24, 2025
@sohamparikh sohamparikh marked this pull request as ready for review January 28, 2025 08:19
@sohamparikh sohamparikh marked this pull request as draft January 28, 2025 08:29
@ServiceNow ServiceNow deleted a comment from sohampnow Feb 5, 2025
@@ -41,6 +51,10 @@ def fused_cross_entropy_forward_backward(
    """
    # Do the forward and backward passes all at once, and fused with dtype conversion.
    # Way faster and more memory-efficient than the pytorch version.
    if apply_loss_mask:
        loss_mask = target != ignore_index
Collaborator

Why are we first integrating the loss mask into the target tensor and then extracting it here again? Does this save any noteworthy amount of memory? Why not pass an explicit, optional binary loss mask to the cross-entropy function?

Member Author

The idea was to let each implementation handle loss masking; we only provide the indices to ignore (following torch's cross-entropy).

Do you prefer creating the loss mask together with the mask indices here, and letting each implementation decide whether to use that or ignore_index? I'm not sure there's a benefit.
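
For context, a hedged plain-PyTorch illustration of the two interfaces being weighed in this thread (not the PR's fused/Triton kernels): ignore indices embedded in the target, following torch's cross-entropy, versus an explicit binary loss mask. Both compute the same masked average.

```python
# Hedged illustration only; not the PR's kernels.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.tensor([3, -100, 7, -100])  # masked positions flagged inside the target

# (a) ignore indices embedded in the target, following torch's cross-entropy:
loss_a = F.cross_entropy(logits, target, ignore_index=-100)

# (b) an explicit binary loss mask passed alongside a "clean" target:
mask = target != -100
per_token = F.cross_entropy(logits, target.clamp(min=0), reduction="none")
loss_b = (per_token * mask).sum() / mask.sum()

print(loss_a.item(), loss_b.item())  # both average over the unmasked tokens only
```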

Collaborator

I don't want to delay this PR unnecessarily. If you and @jlamypoirier think that the combo of ignore indices in the target and apply_loss_mask is the right interface for the cross entropy loss, then let's leave it like that.

Comment on lines -16 to +21
[86, 89, 22255, 1073, 79, 480],
[86, 49152, 89, 22255, 1073, 79],
[8008, 498, 71, 727, 80, 315],
[2210, 8179, 73, 2582, 897, 1178],
[86, 89, 88, 87, 409, 70],
[86, 83, 744, 89, 64, 333],
[86, 89, 1461, 87, 330, 7876],
[86, 89, 88, 49152, 87, 49152],
[86, 49152, 83, 744, 89, 64],
[86, 89, 1461, 49152, 87, 49152],
Member Author

@jlamypoirier can you please check whether this is good with you?
I'm not sure whether this is OK for FIM.

It changed because we're now specifying begin and end tokens in the tokenizer. The test tokenizer doesn't add either by default, so we didn't see 49152 earlier. It is both the bod_id and the eod_id for that tokenizer.

Collaborator

@tscholak tscholak left a comment

LGTM! thanks!

Collaborator

@jlamypoirier jlamypoirier left a comment

Did a few small changes, should be good to merge if nobody opposes:

  • Simplified the cross-entropy handling. I had a deeper look and determined that the best option is to standardize all implementations and always enable masking for target indices < 0, which makes all the apply_loss_mask arguments unnecessary (see the reference sketch after this list). The performance impact won't be a real problem: masking was already always on for Triton, which is the implementation we care about most, and for the compiled ones the overhead should be negligible with compilation. The only remaining issue is that torch's cross-entropy has the different behaviour of masking only -100.
  • Simplified the tests
  • Replaced the random spans in tests with a much faster method
  • Fixed some other tests.
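
A minimal plain-PyTorch reference of that convention, for illustration only (a hedged sketch; `masked_cross_entropy` is a hypothetical name, and the real fused/Triton kernels in this repo are implemented differently):

```python
# Hedged reference of "always mask target indices < 0"; the fused/Triton kernels
# in this PR implement the same semantics differently.
import torch
import torch.nn.functional as F


def masked_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    loss_mask = target >= 0  # masking is always on; no apply_loss_mask flag needed
    per_token = F.cross_entropy(logits, target.clamp(min=0), reduction="none")
    return (per_token * loss_mask).sum() / loss_mask.sum().clamp(min=1)


logits = torch.randn(5, 32)
target = torch.tensor([4, -1, 7, -1, 0])  # any negative index is excluded from the loss
print(masked_cross_entropy(logits, target))
```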

@tscholak
Collaborator

tscholak commented Feb 7, 2025

@sohamparikh, can you run two experiments with SFT data, one with masking turned on and one with masking turned off? I'd like us to see what the performance impact of masking is (tokens/s/gpu), if any.

@tscholak
Collaborator

tscholak commented Feb 7, 2025

ok, so 13% performance loss with masking turned on.
Are we happy with that? Can we expect better? @jlamypoirier?
I was hoping for virtually no penalty, but not sure how realistic that is.

@sohamparikh
Member Author

@tscholak please disregard my plot for fused cross-entropy above; I've deleted it. I found a minor bug in the fused implementation (so the previous runs are unreliable). I'll re-run and post again.

@tscholak
Collaborator

tscholak commented Feb 7, 2025

"with cross_entropy_impl: triton they're pretty similar (smaller vocab size this time)"

Smaller vocab because of #52? How small did you go? We are forced to use the fused implementation for our models right now, and the penalty there is 13% according to your earlier plot?

@jlamypoirier
Collaborator

@tscholak Can you post more details about the 13%? I set the masking to be always on, so that looks really bad

@sohamparikh
Member Author

sohamparikh commented Feb 7, 2025

Please allow me to clear up the confusion I created with the multiple (now deleted) plots. The numbers are generally quite similar with and without loss masking, irrespective of the cross-entropy kernel used.
However, I do see that the first run is (almost) always slower than the second with the exact same config (using a different run directory, and hence a different dataset cache as well).

| Model     | Loss Mask | Cross-Entropy | Vocab Size | Run 1 (tok/s/gpu) | Run 2 (tok/s/gpu) |
|-----------|-----------|---------------|------------|-------------------|-------------------|
| SLAM-5.1B | No        | Triton        | 49k        | 18700             | 18700             |
| SLAM-5.1B | Yes       | Triton        | 49k        | 16200             | 18700             |
| SLAM-5.1B | No        | Fused         | 49k        | 16200             | 18650             |
| SLAM-5.1B | Yes       | Fused         | 49k        | 16200             | 18700             |
| SLAM-5.1B | No        | Triton        | 131k       | Crash             | Crash             |
| SLAM-5.1B | Yes       | Triton        | 131k       | Crash             | Crash             |
| SLAM-5.1B | No        | Fused         | 131k       | 14900             | 17400             |
| SLAM-5.1B | Yes       | Fused         | 131k       | 14900             | 17350             |

So, good news: no 13% penalty

All the runs are in this wandb project:
https://wandb.ai/akshaykalkunte/slam

@jlamypoirier
Collaborator

OK, looks good then. The first run looks like a data-loader bottleneck; the second has the data in the local filesystem cache, so that's normal. We don't care much about masked vs. not masked for the model, since it's running the same code now. What we care about is a possible regression for fused without loss masking, and the lack of difference with Triton seems to show it's not significant.

@tscholak
Collaborator

tscholak commented Feb 7, 2025

In this case, shouldn't we compare with main?

@jlamypoirier
Collaborator

jlamypoirier commented Feb 7, 2025

"In this case, shouldn't we compare with main?"

Ideally yes, but it wouldn't be that useful: the Triton implementation is unchanged, so we can deduce the performance on main from these results.

@sohamparikh
Member Author

sohamparikh commented Feb 7, 2025

Confirmed that the fused numbers with main are also similar.
https://wandb.ai/akshaykalkunte/slam/runs/ohq71uk8v1btd8sc

I'll merge this now

@sohamparikh sohamparikh merged commit 954fde1 into main Feb 7, 2025
3 of 4 checks passed
@sohamparikh sohamparikh deleted the soham/loss-masking-spans branch February 7, 2025 22:55
@jlamypoirier jlamypoirier mentioned this pull request Mar 27, 2025
Successfully merging this pull request may close these issues.

[feat] Implement Loss Masking to Exclude Predefined Token Spans from LM Loss