
fix FlashAttentionKwargs RoPE #35941

Draft
wants to merge 2 commits into base: main

Conversation

@garrett361 commented on Jan 28, 2025

What does this PR do?

#33932 introduced FlashAttentionKwargs as an alternative to using position_ids for padding-free training. The RoPE positional embeddings are not currently applied correctly in the FlashAttentionKwargs code path. This PR ensures that RoPE is used properly for this path.

Code Notes

The Issue

The issue is that if position_ids is not provided, then they are internally generated here:

if position_ids is None:
    position_ids = cache_position.unsqueeze(0)

and these are used to generate the rope embeddings here:

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

These RoPE embeddings are then effectively ~ torch.arange over the whole packed batch, whereas they should be non-trivially generated from the cumulative sequence lengths in FlashAttentionKwargs.
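
For concreteness, here is a minimal, self-contained illustration of the mismatch (plain torch; the two sequence lengths are made up for the example):

import torch

# Illustration only: a packed (padding-free) batch of two sequences of lengths 3 and 2,
# described by cumulative sequence lengths as in FlashAttentionKwargs.
cu_seq_lens = torch.tensor([0, 3, 5])

# What the fallback above effectively produces: positions counting across the whole
# packed batch, ~ torch.arange.
fallback_position_ids = torch.arange(cu_seq_lens[-1].item()).unsqueeze(0)
# tensor([[0, 1, 2, 3, 4]])

# What RoPE actually needs: positions that restart at 0 for each sequence in the pack.
expected_position_ids = torch.tensor([[0, 1, 2, 0, 1]])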

The Fix

Introduce a get_position_ids_from_cu_seq_lens helper which converts from FlashAttentionKwargs -> position_ids, when FlashAttentionKwargs are provided.
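
A minimal sketch of what such a helper could look like, assuming cu_seq_lens has the usual cumulative form [0, l_0, l_0 + l_1, ...]; the PR's actual implementation may differ:

import torch

def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch, not the PR's exact code: given cumulative sequence lengths
    # for a packed batch, build position ids that restart at 0 at every sequence boundary.
    seq_lens = (cu_seq_lens[1:] - cu_seq_lens[:-1]).tolist()
    position_ids = torch.cat([torch.arange(length) for length in seq_lens])
    return position_ids.unsqueeze(0)  # shape (1, total_tokens), matching the packed layout

# get_position_ids_from_cu_seq_lens(torch.tensor([0, 3, 5])) -> tensor([[0, 1, 2, 0, 1]])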

Because many other models inherit from LlamaDecoder, this change propagates to many other models via modular_model_converter.py.

Tests

The solution is tested in LlamaModelTest::test_attn_mask_position_ids_flash_attn_equality, which checks that the logits in the following cases are consistent with each other:

  • No padding-free, just padding and attention masks
  • Padding free via position_ids
  • Padding free via FlashAttentionKwargs

This test fails on the latest main without the above fix.
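
For reference, a minimal sketch of how the three input formulations relate, using made-up token ids and lengths (plain tensors only; an actual forward pass additionally requires a flash-attention-2 capable setup and is omitted here):

import torch

# Case 1: padded batch of two sequences (lengths 3 and 2) plus an attention mask.
padded_input_ids = torch.tensor([[11, 12, 13], [21, 22, 0]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

# Case 2: padding-free, packed into one row with explicit position_ids that restart
# at 0 for every sequence.
packed_input_ids = torch.tensor([[11, 12, 13, 21, 22]])
packed_position_ids = torch.tensor([[0, 1, 2, 0, 1]])

# Case 3: padding-free via FlashAttentionKwargs-style cumulative sequence lengths;
# the fix makes this path recover the same per-sequence positions as case 2.
cu_seq_lens = torch.tensor([0, 3, 5], dtype=torch.int32)
max_length = 3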

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@garrett361 marked this pull request as draft on January 28, 2025 at 17:35
@Rocketknight1 (Member)

cc @Abhishek-TAMU @ArthurZucker because this is an update to #33932
