[PyTorch] Enable generic QK norm support (+ RMSNorm/LayerNorm) #1966
Conversation
Great, thanks for looking into this so quickly. This all looks good, but:
I have to admit there is one more thing that I just noticed. The Qwen3 models apply QK normalization before RoPE (see here), in contrast to this implementation, which is based on Llama 4.
I was not aware that there are two different formulations for this. Sorry for that.
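For context, here is a minimal PyTorch sketch of the two orderings being discussed. This is illustrative only; `rms_norm` and `apply_rope` are toy helpers, not the Transformer Engine implementation:

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm over the head dimension (learnable weight omitted for brevity).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotate-half formulation of rotary position embeddings.
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

# Toy tensors: [batch, seq, heads, head_dim].
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
pos = torch.arange(8, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 32, dtype=torch.float32) / 32))
angles = torch.outer(pos, inv_freq)  # [seq, head_dim // 2]
cos = torch.cat((angles, angles), dim=-1).cos()[None, :, None, :]
sin = torch.cat((angles, angles), dim=-1).sin()[None, :, None, :]

# Qwen3-style: QK norm applied *before* RoPE.
q_qwen = apply_rope(rms_norm(q), cos, sin)
k_qwen = apply_rope(rms_norm(k), cos, sin)

# Llama-4-style: QK norm applied *after* RoPE.
q_llama = rms_norm(apply_rope(q, cos, sin))
k_llama = rms_norm(apply_rope(k, cos, sin))
```

The two results differ because RoPE rotates pairs of channels while RMSNorm rescales whole head vectors, so the operations do not commute in general.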
This flexibility might be worth it; I've added support for it.
/te-ci pytorch
Great, that was super quick! This looks very good to me 🥳
/te-ci pytorch
LGTM
Description
For training stabilization purposes, QK tensors might be normalized. This PR adds:
- RMSNorm/LayerNorm as normalization types (in addition to L2Normalization).
- Applying QK normalization either before or after RoPE (following both the Qwen and Llama approaches).

Extension of #1864
Fixes #1958
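As a rough sketch of what the three normalization options compute over the head dimension (the `qk_norm` helper and its signature below are hypothetical illustrations, not the actual Transformer Engine API):

```python
from typing import Optional

import torch
import torch.nn.functional as F

def qk_norm(x: torch.Tensor, norm_type: str,
            weight: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            eps: float = 1e-6) -> torch.Tensor:
    # Dispatch on the requested normalization, applied over the head dimension.
    if norm_type == "L2Normalization":
        # Scale each head vector to unit L2 norm.
        return x / x.norm(dim=-1, keepdim=True).clamp_min(eps)
    if norm_type == "RMSNorm":
        # Root-mean-square scaling, with an optional learnable weight.
        out = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return out if weight is None else out * weight
    if norm_type == "LayerNorm":
        # Mean-centering plus variance scaling, with optional affine params.
        return F.layer_norm(x, x.shape[-1:], weight, bias, eps)
    raise ValueError(f"Unsupported norm type: {norm_type}")

q = torch.randn(2, 16, 8, 64)  # [batch, seq, heads, head_dim]
q_l2 = qk_norm(q, "L2Normalization")
q_rms = qk_norm(q, "RMSNorm", weight=torch.ones(64))
```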