Optimize the CUDA Kernel performance of Paddle rms_norm #77098
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff            @@
##           develop   #77098   +/- ##
===========================================
  Coverage         ?   100.00%
===========================================
  Files            ?         4
  Lines            ?        37
  Branches         ?         0
===========================================
  Hits             ?        37
  Misses           ?         0
  Partials         ?         0
```
Force-pushed from 0604b69 to 65aff3e.
This reverts commit ccaaa1b.
Force-pushed from 65aff3e to 19513e8.
This reverts commit 19513e8.
wanghuancoder
left a comment
LGTM
A-nnonymous
left a comment
LGTM for the kernel in most cases, but it needs polishing afterwards.
```cpp
}
inline __device__ res_t project(acc_t acc) const {
  const auto mean = static_cast<scalar_t>(acc.mean);
  const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
```
In a follow-up, please use explicit type declarations here instead of auto.
done
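For reference, the explicit-type version would look roughly like this (a sketch only; the concrete types, e.g. float for the Welford-count arithmetic, are assumptions, not the merged code):

```cpp
// Explicit types instead of `auto`; the exact types in the merged kernel
// may differ.
const scalar_t mean = static_cast<scalar_t>(acc.mean);
const float divisor = acc.nf > correction ? acc.nf - correction : 0.0f;
```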
```cpp
if (threadIdx.x == 0) {
  T_ACC m1;  // mean
  T_ACC m2;  // var
  thrust::pair<T_ACC, T_ACC> res = welford_op.project(val);
```
In future changes, try to avoid introducing thrust-related data structures and algorithms, to make porting easier.
done
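A sketch of the thrust-free direction suggested above, with a plain aggregate standing in for thrust::pair (WelfordResult is a hypothetical name, not the merged code):

```cpp
// Plain aggregate replacing thrust::pair<T_ACC, T_ACC>, removing the
// thrust dependency from the kernel to ease porting.
template <typename T_ACC>
struct WelfordResult {
  T_ACC mean;  // first element of the former pair
  T_ACC var;   // second element of the former pair
};

// welford_op.project() would then return WelfordResult<T_ACC>:
//   WelfordResult<T_ACC> res = welford_op.project(val);
//   m1 = res.mean;
//   m2 = res.var;
```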
```cpp
int64_t thread_x = static_cast<int64_t>(blockIdx.x) * block_dim_x +
                   static_cast<int64_t>(threadIdx.x);

int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1);
```
There is a more efficient way to compute lane_id; keep this overhead in mind when optimizing.
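Two standard cheaper alternatives, sketched here as suggestions rather than the fix actually adopted:

```cpp
// (a) Read the lane id directly from the hardware register via inline PTX,
// avoiding the multiply-add on the thread indices entirely.
__device__ __forceinline__ unsigned LaneId() {
  unsigned lane;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane));
  return lane;
}

// (b) When blockDim.x is a multiple of 32, threadIdx.y * blockDim.x is a
// multiple of the warp size, so the flat-index computation above reduces to
//   int lane_id = threadIdx.x & (kWarpSize - 1);
```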
zrr1999
left a comment
LGTM
python/paddle/nn/functional/norm.py
Outdated
```diff
 def rms_norm(
     input: Tensor,
-    normalized_shape: int | Sequence[int],
+    normalized_shape: list[int],
```
Use Sequence[int] instead.
done
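The resulting annotation would presumably read roughly as follows (a hedged reconstruction; the trailing parameters and defaults are assumptions, not the merged signature):

```python
from __future__ import annotations

from typing import Sequence

from paddle import Tensor

def rms_norm(
    input: Tensor,
    normalized_shape: int | Sequence[int],
    weight: Tensor | None = None,  # assumed: the PR allows an empty weight
    epsilon: float = 1e-5,         # assumed default, matching the yaml diff
) -> Tensor: ...
```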
paddle/phi/ops/yaml/ops.yaml
Outdated
```diff
 - op: rms_norm
-  args: (Tensor x, Tensor scale, float epsilon)
+  args: (Tensor x, Tensor scale, IntArray normalized_shape={}, double epsilon=1e-5)
```
If the Python API pins normalized_shape to int types, int64_t[] can be used here.
done
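The suggested declaration would then read roughly (a sketch, not the merged yaml):

```yaml
- op: rms_norm
  args: (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon=1e-5)
```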
XiaoguangHu01
left a comment
LGTM
…#77098)

* accuracy and Torch alignment
* support rms_norm behavior to be the same as torch
* fix rms_norm_xpu_kernel
* add valueError_test
* Revert "add valueError_test" (this reverts commit ccaaa1b)
* Reapply "add valueError_test" (this reverts commit 19513e8)
* optimize performance
* add vectorization
* fix
* fix dtype of normalized_shape
PR Category
Operator Mechanism
PR Types
Improvements
Description
Optimize rms_norm:
a. [Done] Accuracy: bitwise-aligned with torch (fp16, bf16, fp32, fp64).
b. [Done] Functionality: Paddle's old rms_norm lacked the normalized_shape parameter and did not support an empty weight; it is now aligned with Torch.
c. [Done] Performance: forward and backward are now roughly on par with Torch.
d. [Done] Performance, stage 2: optimized the API pre-processing and applied wider vectorization, branch, and fusion optimizations inside the kernel (see the sketch after this list); performance now beats torch across the board.
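As an illustration of the vectorized-access idea in (d), a minimal sketch (assuming aligned rows whose length is a multiple of the vector width; illustrative only, not the merged kernel):

```cpp
// Aligned vector wrapper: one load/store instruction moves VecSize elements.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VecType {
  T val[VecSize];
};

// Each thread copies VecSize contiguous elements with a single vector load,
// cutting the number of memory transactions by a factor of VecSize.
template <typename T, int VecSize>
__device__ __forceinline__ void LoadVec(const T* src, T* dst) {
  *reinterpret_cast<VecType<T, VecSize>*>(dst) =
      *reinterpret_cast<const VecType<T, VecSize>*>(src);
}
```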
rms_norm forward (before vs. after optimization: on average 2.2x faster than Paddle's old implementation and 6% faster than Torch):


rms_norm backward (before vs. after optimization: on average 1.6x faster than Paddle's old implementation and 31% faster than Torch):