Improve the efficiency of the RMSNorm aggregation #179
Comments
I would like to try this one. #take @lancerts. Edit: not sure how well I can do this, but I would like to try, so I'm not sure it's a good idea to assign it to me.
You can refer to the LayerNorm implementation in the Triton tutorial.
The Triton tutorial at https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py uses atomic operations. We can instead use a method similar to https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/layer_norm.py#L106, which does not use atomic operations (sketches of both patterns appear below).
Yes, I went with an approach similar to the Liger layer_norm implementation; however, I'm having some issues with numerical stability. Is it OK to increase the absolute tolerance in the tests? I'll create a draft PR as soon as possible; I need to fix some issues first. @lancerts
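For context on the contrast drawn in the comments above, here is a minimal, hypothetical sketch of the tutorial-style aggregation, where every Triton program atomically adds its per-row contribution into one shared dW buffer. The kernel and helper names (`_dw_atomic_kernel`, `dw_via_atomics`) and the `dy * x_hat` contribution are illustrative assumptions, not the actual Liger or tutorial code.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _dw_atomic_kernel(dy_ptr, x_hat_ptr, dw_ptr, n_cols, stride, BLOCK_SIZE: tl.constexpr):
    # One program per row: each program computes its row's contribution to dW
    # (assumed here to be dy * x_hat) and atomically adds it into a single
    # shared (n_cols,) buffer. The atomics serialize when many rows update the
    # same columns, which is the overhead this issue wants to avoid.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    dy = tl.load(dy_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    x_hat = tl.load(x_hat_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    tl.atomic_add(dw_ptr + cols, dy * x_hat, mask=mask)


def dw_via_atomics(dy: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    # Hypothetical launch: one program per row, all accumulating into dw.
    n_rows, n_cols = dy.shape
    dw = torch.zeros(n_cols, device=dy.device, dtype=torch.float32)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    _dw_atomic_kernel[(n_rows,)](dy, x_hat, dw, n_cols, dy.stride(0), BLOCK_SIZE=BLOCK_SIZE)
    return dw
```

The non-atomic alternative referenced in the comments is sketched after the feature pitch below.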
🚀 The feature, motivation and pitch
Modify this line, https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rms_norm.py#L306, replacing the PyTorch sum with a partial aggregation in Triton. For reference, see
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/layer_norm.py#L106,
which does two levels of aggregation, the first in Triton and the second in torch (more efficient).
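A minimal sketch of the proposed two-level aggregation, assuming the backward kernel is given a per-program partial buffer: level one happens in Triton, where each program reduces a strip of rows into its own row of a `(num_programs, n_cols)` buffer without atomics; level two is a single small `sum(dim=0)` in PyTorch, standing in for the full per-row sum currently done at rms_norm.py#L306. All names, the `dy * x_hat` contribution, and the tiling parameters are illustrative assumptions, not the actual Liger kernels.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _dw_partial_kernel(
    dy_ptr, x_hat_ptr, dw_partial_ptr,
    n_rows, n_cols, stride,
    ROWS_PER_PROGRAM: tl.constexpr, BLOCK_SIZE: tl.constexpr,
):
    # Level 1 (Triton): each program reduces ROWS_PER_PROGRAM rows into a
    # private float32 accumulator, then stores it to its own row of
    # dw_partial. No atomics are needed because no two programs ever write
    # the same row of the partial buffer.
    pid = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    col_mask = cols < n_cols
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    row_start = pid * ROWS_PER_PROGRAM
    for i in range(ROWS_PER_PROGRAM):
        row = row_start + i
        mask = col_mask & (row < n_rows)
        dy = tl.load(dy_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
        x_hat = tl.load(x_hat_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
        acc += dy * x_hat
    tl.store(dw_partial_ptr + pid * n_cols + cols, acc, mask=col_mask)


def dw_two_level(dy: torch.Tensor, x_hat: torch.Tensor, rows_per_program: int = 16) -> torch.Tensor:
    n_rows, n_cols = dy.shape
    num_programs = triton.cdiv(n_rows, rows_per_program)
    dw_partial = torch.empty((num_programs, n_cols), device=dy.device, dtype=torch.float32)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    _dw_partial_kernel[(num_programs,)](
        dy, x_hat, dw_partial, n_rows, n_cols, dy.stride(0),
        ROWS_PER_PROGRAM=rows_per_program, BLOCK_SIZE=BLOCK_SIZE,
    )
    # Level 2 (PyTorch): reduce only num_programs rows instead of n_rows rows.
    return dw_partial.sum(dim=0)
```

Compared with the current code, the host-side reduction shrinks from `n_rows` rows to `num_programs` rows, which is where the efficiency gain described above would come from.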
Alternatives
No response
Additional context
No response