
Conversation

@zhengshengning
Contributor

@zhengshengning zhengshengning commented Dec 25, 2025

PR Category

Operator Mechanism

PR Types

Improvements

Description

Optimize rms_norm
a. [Done] Accuracy: bit-wise aligned with Torch (fp16, bf16, fp32, fp64).
b. [Done] Functionality: the old Paddle rms_norm lacked the normalized_shape parameter and did not support an empty weight; both are now aligned with Torch.
c. [Done] Performance: forward and backward are now roughly on par with Torch.
d. [Done] Performance, part 2: optimized the API pre-processing and applied wider vectorization, branch, and fusion optimizations inside the kernel; performance now beats Torch across the board.

rms_norm forward: (comparing before/after optimization: an average 2.2x speedup over the old Paddle implementation and an average 6% gain over Torch)
[benchmark chart]
rms_norm backward: (comparing before/after optimization: an average 1.6x speedup over the old Paddle implementation and an average 31% gain over Torch)
[benchmark chart]
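
For reference, here is a minimal host-side C++ sketch of the RMSNorm computation the kernel implements (the name rms_norm_ref and the flat row-major layout are illustrative assumptions, not code from this PR); weight may be null, matching the empty-weight case in item b:

#include <cmath>
#include <cstdint>

// Illustrative reference only: normalize each row of norm_size elements by its
// root-mean-square and optionally scale by weight (nullptr means identity).
void rms_norm_ref(const float* x, const float* weight, float* y,
                  int64_t rows, int64_t norm_size, float epsilon) {
  for (int64_t r = 0; r < rows; ++r) {
    const float* row = x + r * norm_size;
    float sum_sq = 0.0f;
    for (int64_t i = 0; i < norm_size; ++i) sum_sq += row[i] * row[i];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / norm_size + epsilon);
    for (int64_t i = 0; i < norm_size; ++i) {
      const float w = weight ? weight[i] : 1.0f;
      y[r * norm_size + i] = row[i] * inv_rms * w;
    }
  }
}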

@paddle-bot

paddle-bot bot commented Dec 25, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@codecov-commenter

codecov-commenter commented Dec 26, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@20ae51f). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #77098   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         4           
  Lines              ?        37           
  Branches           ?         0           
===========================================
  Hits               ?        37           
  Misses             ?         0           
  Partials           ?         0           

☔ View full report in Codecov by Sentry.

@zrr1999 zrr1999 force-pushed the opt_rms_norm branch 2 times, most recently from 0604b69 to 65aff3e on December 30, 2025 12:36
@zhengshengning zhengshengning changed the title from "accuracy and Torch alignment" to "Optimize the Cuda Kernel performance of Paddle rms_norm" on Dec 31, 2025
wanghuancoder
wanghuancoder previously approved these changes Jan 4, 2026
Contributor

@wanghuancoder wanghuancoder left a comment


LGTM

Contributor

@A-nnonymous A-nnonymous left a comment


LGTM for the kernel in most cases, but it needs some polish afterwards.

}
inline __device__ res_t project(acc_t acc) const {
  const auto mean = static_cast<scalar_t>(acc.mean);
  const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
Contributor


In a follow-up, please use explicit type declarations here and avoid auto.

Contributor Author


done
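
For illustration, an explicitly typed version of those two lines might read as below; the concrete types are assumptions inferred from the surrounding snippet, not the merged code.

// Sketch with explicit types instead of auto; T_ACC stands in for the
// accumulator's scalar type (an assumption), mirroring the later snippet.
const scalar_t mean = static_cast<scalar_t>(acc.mean);
const T_ACC divisor = acc.nf > correction ? acc.nf - correction : T_ACC(0);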

if (threadIdx.x == 0) {
  T_ACC m1;  // mean
  T_ACC m2;  // var
  thrust::pair<T_ACC, T_ACC> res = welford_op.project(val);
Contributor


In future changes, please avoid introducing thrust data structures and algorithms where possible, to keep the code easy to port.

Contributor Author


done
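
As a sketch of the thrust-free direction suggested above, a plain aggregate can replace thrust::pair; the struct name MeanVar and the usage shown are illustrative, not the merged code.

// Sketch: a plain aggregate in place of thrust::pair<T_ACC, T_ACC>, so the
// kernel carries no thrust dependency and stays easy to port.
template <typename T_ACC>
struct MeanVar {
  T_ACC mean;
  T_ACC var;
};
// The epilogue above would then read, for example:
//   MeanVar<T_ACC> res = welford_op.project(val);
//   m1 = res.mean;
//   m2 = res.var;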

int64_t thread_x = static_cast<int64_t>(blockIdx.x) * block_dim_x +
                    static_cast<int64_t>(threadIdx.x);

int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1);
Contributor


There is a more efficient way to compute lane_id; keep this overhead in mind when optimizing.
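
One common cheaper way to obtain the lane id (an assumption about what the reviewer has in mind, not code from this PR) is to read the %laneid special register instead of recomputing it from thread indices:

// Sketch: fetch the warp lane id from the %laneid special register rather than
// computing (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1).
__device__ __forceinline__ unsigned int LaneId() {
  unsigned int lane;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane));
  return lane;
}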

zrr1999
zrr1999 previously approved these changes Jan 5, 2026
Member

@zrr1999 zrr1999 left a comment


LGTM

def rms_norm(
input: Tensor,
normalized_shape: int | Sequence[int],
normalized_shape: list[int],
Member


Please use Sequence[int] instead.

Contributor Author


done


- op: rms_norm
args: (Tensor x, Tensor scale, float epsilon)
args: (Tensor x, Tensor scale, IntArray normalized_shape={}, double epsilon= 1e-5)
Contributor


If the Python API explicitly constrains normalized_shape to int types, int64_t[] can be used here.

Contributor Author


done

zhangbo9674
zhangbo9674 previously approved these changes Jan 5, 2026
@zrr1999 zrr1999 dismissed stale reviews from zhangbo9674 and themself via 1837a81 January 5, 2026 13:12
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment


LGTM

@zhengshengning zhengshengning merged commit a1d519c into PaddlePaddle:develop Jan 6, 2026
148 of 167 checks passed
zrr1999 pushed a commit to zrr1999/Paddle that referenced this pull request Jan 8, 2026
…#77098)

* accuracy and Torch alignment

* support rms_norm behavior to be the same as torch

* fix rms_norm_xpu_kernel

* add valueError_test

* Revert "add valueError_test"

This reverts commit ccaaa1b.

* Reapply "add valueError_test"

This reverts commit 19513e8.

* optimize performance

* add vectorization

* fix

* fix dtype of normalized_shape
