What does this PR do?
#33932 introduced FlashAttentionKwargs as an alternative to using position_ids for padding-free training. The RoPE positional embeddings are not currently applied correctly in the FlashAttentionKwargs code path. This PR ensures that RoPE is applied properly for this path.

Code Notes
The Issue
The issue is that if position_ids is not provided, it is internally generated here:

transformers/src/transformers/models/llama/modeling_llama.py, lines 561 to 562 (at ec7afad)

and these values are then used to generate the RoPE embeddings here:

transformers/src/transformers/models/llama/modeling_llama.py, lines 570 to 571 (at ec7afad)
These RoPE embeddings are effectively built from ~torch.arange positions, whereas they should be non-trivially generated from the values in FlashAttentionKwargs.
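For a concrete toy example of the mismatch, consider two sequences of lengths 3 and 2 packed into a single row (the tensors below are illustrative, not code from the PR):

```python
import torch

# Two sequences of lengths 3 and 2 packed into one row of 5 tokens.
# FlashAttentionKwargs carries the boundaries as cumulative sequence lengths.
cu_seq_lens = torch.tensor([0, 3, 5])

# What the model falls back to today when position_ids is None:
fallback_position_ids = torch.arange(5).unsqueeze(0)   # tensor([[0, 1, 2, 3, 4]])

# What RoPE actually needs for padding-free training: positions restart at each boundary.
expected_position_ids = torch.tensor([[0, 1, 2, 0, 1]])
```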
The Fix
Introduce a get_position_ids_from_cu_seq_lens helper which converts FlashAttentionKwargs -> position_ids when the former is provided. Because many other models inherit from LlamaDecoder, this change propagates to many other models via modular_model_converter.py.
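As a rough illustration (not necessarily the exact implementation in this PR), such a conversion could look like:

```python
import torch

def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    """Build position_ids for a packed (padding-free) batch from cumulative sequence lengths.

    cu_seq_lens: 1D int tensor, e.g. tensor([0, 3, 5]) for sequences of lengths 3 and 2.
    Returns a (1, total_tokens) tensor whose positions restart at every sequence boundary.
    """
    seq_lengths = (cu_seq_lens[1:] - cu_seq_lens[:-1]).tolist()
    position_ids = torch.cat([torch.arange(length) for length in seq_lengths])
    return position_ids.unsqueeze(0)

# Example: get_position_ids_from_cu_seq_lens(torch.tensor([0, 3, 5]))
# -> tensor([[0, 1, 2, 0, 1]])
```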
Tests
The solution is tested in LlamaModelTest::test_attn_mask_position_ids_flash_attn_equality, which checks that the logits in the following cases are consistent with each other:

- position_ids
- FlashAttentionKwargs

This test fails on latest main without the above fix.
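Roughly, the equivalence being asserted can be sketched as follows (a simplified illustration, not the actual test body; it assumes a flash-attention-enabled model and the cu_seq_lens_q / cu_seq_lens_k / max_length_q / max_length_k kwarg names carried by FlashAttentionKwargs):

```python
import torch

def assert_padding_free_paths_match(model, input_ids, seq_lengths, atol=1e-4):
    """Logits from explicit position_ids should match logits from FlashAttentionKwargs.

    input_ids: packed batch of shape (1, total_tokens); seq_lengths: list of per-sequence lengths.
    Assumes the model runs with flash attention and accepts the kwargs used below.
    """
    cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lengths), dim=0).to(torch.int32)
    position_ids = torch.cat([torch.arange(length) for length in seq_lengths]).unsqueeze(0)

    with torch.no_grad():
        logits_from_position_ids = model(input_ids=input_ids, position_ids=position_ids).logits
        logits_from_fa_kwargs = model(
            input_ids=input_ids,
            cu_seq_lens_q=cu_seq_lens,
            cu_seq_lens_k=cu_seq_lens,
            max_length_q=max(seq_lengths),
            max_length_k=max(seq_lengths),
        ).logits

    torch.testing.assert_close(logits_from_position_ids, logits_from_fa_kwargs, atol=atol, rtol=0)
```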
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.