
Commit 69974ff

zhengshengning authored and zrr1999 committed
Optimize the Cuda Kernel performance of Paddle rms_norm (PaddlePaddle#77098)
* accuracy and Torch alignment
* support rms_norm behavior to be the same as torch
* fix rms_norm_xpu_kernel
* add valueError_test
* Revert "add valueError_test" (this reverts commit ccaaa1b)
* Reapply "add valueError_test" (this reverts commit 19513e8)
* optimize performance
* add vectorization
* fix
* fix dtype of normalized_shape
1 parent 1b700c1 commit 69974ff
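For orientation, rms_norm normalizes each row by its root-mean-square over the trailing normalized_shape dims and then applies a learned scale. A minimal host-side C++ sketch of those semantics (an illustration only, not the optimized CUDA kernel in this commit; rows/cols are the flattened dims before and after begin_norm_axis):

#include <cmath>
#include <cstdint>
#include <vector>

// Reference semantics: rows = product of dims before begin_norm_axis,
// cols = product of normalized_shape; invvar[r] = 1/sqrt(mean(x[r]^2) + eps).
void RmsNormReference(const std::vector<float>& x,
                      const std::vector<float>& scale,
                      int64_t rows,
                      int64_t cols,
                      double epsilon,
                      std::vector<float>* y,
                      std::vector<float>* invvar) {
  y->assign(rows * cols, 0.0f);
  invvar->assign(rows, 0.0f);
  for (int64_t r = 0; r < rows; ++r) {
    double sum_sq = 0.0;
    for (int64_t c = 0; c < cols; ++c) {
      const double v = x[r * cols + c];
      sum_sq += v * v;
    }
    const float inv =
        1.0f / std::sqrt(static_cast<float>(sum_sq / cols + epsilon));
    (*invvar)[r] = inv;
    for (int64_t c = 0; c < cols; ++c) {
      (*y)[r * cols + c] = x[r * cols + c] * inv * scale[c];
    }
  }
}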

File tree

12 files changed: +1169 additions, -1147 deletions


paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 20 additions & 6 deletions
@@ -3428,20 +3428,34 @@ bool RmsNormOpInferSymbolicShape(
       infer_context->GetShapeOrDataForValue(op->operand_source(0));
   const auto &scale_shape_or_data =
       infer_context->GetShapeOrDataForValue(op->operand_source(1));
+  std::vector<int64_t> normalized_shape =
+      paddle::dialect::details::GetVectorAttr<int64_t>(op, "normalized_shape");
 
   std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape();
-  // NOTE(large-tensor): tensor indices are small integers
-  int begin_norm_axis = static_cast<int>(x_dims.size() - 1);
+  int x_dims_size = x_dims.size();
+  int normalized_shape_size = normalized_shape.size();
+  int begin_norm_axis = x_dims_size - normalized_shape_size;
 
   // Flatten x_dims to 2D and get dim[1]
-  symbol::DimExpr matrix_dim_1 = x_dims[begin_norm_axis];
-  for (std::size_t i = begin_norm_axis + 1; i < x_dims.size(); ++i) {
-    matrix_dim_1 = matrix_dim_1 * x_dims[i];
+  PADDLE_ENFORCE_LT(normalized_shape_size,
+                    x_dims_size,
+                    "normalized_shape must be less than x_dims");
+  for (int i = 0; i < normalized_shape_size; i++) {
+    infer_context->AddEqualCstr(
+        x_dims[x_dims_size - i - 1],
+        symbol::DimExpr(normalized_shape[normalized_shape_size - i - 1]));
   }
 
   if (!scale_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
     std::vector<symbol::DimExpr> scale_dims = scale_shape_or_data.shape();
-    infer_context->AddEqualCstr(scale_dims[0], matrix_dim_1);
+    PADDLE_ENFORCE_EQ(
+        scale_dims.size(),
+        normalized_shape_size,
+        "scale_dims.size() must be equal to normalized_shape_size");
+    for (int i = 0; i < normalized_shape_size; i++) {
+      infer_context->AddEqualCstr(scale_dims[i],
+                                  symbol::DimExpr(normalized_shape[i]));
+    }
   }
 
   // Set output shapes
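In plain terms, the new rule splits x at begin_norm_axis = rank(x) - len(normalized_shape) and pins the trailing dims of x (and all dims of scale) to normalized_shape, rather than hard-coding the last dim. A standalone sketch of that index arithmetic (illustrative names, not Paddle's API):

#include <cassert>
#include <cstdint>
#include <vector>

// Trailing normalized_shape dims of x must match; the split point is
// their count (the same constraint the hunk adds via AddEqualCstr).
int BeginNormAxis(const std::vector<int64_t>& x_dims,
                  const std::vector<int64_t>& normalized_shape) {
  const int x_dims_size = static_cast<int>(x_dims.size());
  const int normalized_shape_size = static_cast<int>(normalized_shape.size());
  assert(normalized_shape_size < x_dims_size);
  for (int i = 0; i < normalized_shape_size; i++) {
    assert(x_dims[x_dims_size - i - 1] ==
           normalized_shape[normalized_shape_size - i - 1]);
  }
  return x_dims_size - normalized_shape_size;
}
// e.g. x_dims = {8, 16, 512}, normalized_shape = {512} -> begin_norm_axis = 2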

paddle/phi/infermeta/backward.cc

Lines changed: 9 additions & 22 deletions
@@ -1671,28 +1671,15 @@ void FusedRmsNormQuantGradInferMeta(const MetaTensor& x,
   }
 }
 
-PADDLE_API void RMSNormGradInferMeta(const MetaTensor& x,
-                                     const MetaTensor& scale,
-                                     const MetaTensor& invvar,
-                                     const MetaTensor& y_grad,
-                                     float epsilon,
-                                     MetaTensor* x_grad,
-                                     MetaTensor* scale_grad) {
-  PADDLE_ENFORCE_EQ(
-      x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
-          x.dtype() == DataType::BFLOAT16,
-      true,
-      common::errors::InvalidArgument(
-          "The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
-          x.dtype()));
-  PADDLE_ENFORCE_EQ(
-      scale.dtype() == DataType::FLOAT32 ||
-          scale.dtype() == DataType::FLOAT16 ||
-          scale.dtype() == DataType::BFLOAT16,
-      true,
-      common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
-                                      "FLOAT16 or BFLOAT16, but got [%s]",
-                                      scale.dtype()));
+PADDLE_API void RMSNormGradInferMeta(
+    const MetaTensor& x,
+    const MetaTensor& scale,
+    const MetaTensor& invvar,
+    const MetaTensor& y_grad,
+    const std::vector<int64_t>& normalized_shape,
+    double epsilon,
+    MetaTensor* x_grad,
+    MetaTensor* scale_grad) {
   if (x_grad && x) {
     x_grad->share_meta(x);
   }
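The slimmed-down grad infer-meta now only mirrors metadata from the forward inputs; the dtype whitelist is gone from this function. A toy sketch of that share-meta flow, with a hypothetical Meta stand-in for MetaTensor (the scale_grad branch is below the visible context and assumed analogous):

#include <cstdint>
#include <vector>

// Hypothetical stand-in for phi::MetaTensor, just to show the flow.
struct Meta {
  std::vector<int64_t> dims;
  int dtype = 0;
};

// Gradients inherit shape/dtype from their forward counterparts,
// mirroring x_grad->share_meta(x) in the hunk above.
void RmsNormGradMetaSketch(const Meta& x,
                           const Meta* scale,
                           Meta* x_grad,
                           Meta* scale_grad) {
  if (x_grad) *x_grad = x;
  if (scale_grad && scale) *scale_grad = *scale;
}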

paddle/phi/infermeta/backward.h

Lines changed: 9 additions & 7 deletions
@@ -629,13 +629,15 @@ PADDLE_API void FusedRmsNormQuantGradInferMeta(const MetaTensor& x,
                                                MetaTensor* norm_weight_grad,
                                                MetaTensor* norm_bias_grad);
 
-PADDLE_API void RMSNormGradInferMeta(const MetaTensor& x,
-                                     const MetaTensor& scale,
-                                     const MetaTensor& invvar,
-                                     const MetaTensor& y_grad,
-                                     float epsilon,
-                                     MetaTensor* x_grad,
-                                     MetaTensor* scale_grad);
+PADDLE_API void RMSNormGradInferMeta(
+    const MetaTensor& x,
+    const MetaTensor& scale,
+    const MetaTensor& invvar,
+    const MetaTensor& y_grad,
+    const std::vector<int64_t>& normalized_shape,
+    double epsilon,
+    MetaTensor* x_grad,
+    MetaTensor* scale_grad);
 
 PADDLE_API void RnnGradInferMeta(
     const MetaTensor& x,

paddle/phi/infermeta/binary.cc

Lines changed: 62 additions & 27 deletions
@@ -3843,36 +3843,68 @@ void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x,
 
 void RmsNormInferMeta(const MetaTensor& x,
                       const MetaTensor& scale,
-                      float epsilon,
+                      const std::vector<int64_t>& normalized_shape,
+                      double epsilon,
                       MetaTensor* y,
                       MetaTensor* invvar) {
   auto x_dim = x.dims();
-  auto x_ndim = x_dim.size();
+  // std::vector<int64_t> normalized_shape_data = normalized_shape.GetData();
+  int normalized_shape_size = normalized_shape.size();
+  int x_dims_size = x_dim.size();
+  int begin_norm_axis = x_dims_size - normalized_shape_size;
 
-  auto matrix_dim = common::flatten_to_2d(x_dim, x_ndim - 1);
+  PADDLE_ENFORCE_GT(begin_norm_axis,
+                    0,
+                    common::errors::InvalidArgument(
+                        "'begin_norm_axis' in Op(LayerNorm) should be "
+                        "greater than zero. But received [%d].",
+                        begin_norm_axis));
+
+  PADDLE_ENFORCE_LT(
+      begin_norm_axis,
+      x_dims_size,
+      common::errors::InvalidArgument(
+          "'begin_norm_axis' must be less than the dimensions of X,"
+          "But received 'begin_norm_axis' is [%d],"
+          "received the dimensions of X is [%d].",
+          begin_norm_axis,
+          x_dims_size));
+
+  for (int i = 0; i < normalized_shape_size; i++) {
+    PADDLE_ENFORCE_EQ(x_dim[x_dims_size - i - 1],
+                      normalized_shape[normalized_shape_size - i - 1],
+                      common::errors::InvalidArgument(
+                          "The %d-th dimension of X is not equal to the %d-th "
+                          "dimension of NormalizedShape.",
+                          x_dims_size - i - 1,
+                          normalized_shape_size - i - 1));
+  }
 
-  int64_t right = matrix_dim[1];
   if (scale) {
-    PADDLE_ENFORCE_EQ(scale.dims().size(),
-                      1,
+    auto scale_dim = scale.dims();
+    PADDLE_ENFORCE_EQ(scale_dim.size(),
+                      normalized_shape_size,
                       common::errors::InvalidArgument(
-                          "The dimensions of Input(Scale) must be 1, but "
-                          "received dimensions of "
-                          "Input(Scale) is [%d]",
-                          scale.dims().size()));
+                          "The dimensions of Input(Scale) must be equal to the "
+                          "dimensions of NormalizedShape. "
+                          "But received: the dimensions of Input(Scale) is "
+                          "[%d], the dimensions of NormalizedShape is [%d].",
+                          scale_dim.size(),
+                          normalized_shape_size));
+    for (int i = 0; i < normalized_shape_size; i++) {
+      PADDLE_ENFORCE_EQ(scale_dim[i],
+                        normalized_shape[i],
+                        common::errors::InvalidArgument(
+                            "The %d-th dimension of Input(Scale) is not equal "
+                            "to the %d-th dimension of NormalizedShape.",
+                            i,
+                            i));
+    }
   }
 
-  PADDLE_ENFORCE_EQ(
-      scale.dims()[0],
-      right,
-      common::errors::InvalidArgument(
-          "The first dimension value of Input(Scale) must equal to be the "
-          "second dimension value of the flattened 2D matrix of Input(X), "
-          "But received the first dimension value of Input(Scale) is "
-          "[%d], the second dimension value of the flattened 2D matrix of "
-          " Input(Scale) is [%d].",
-          scale.dims()[0],
-          right));
+  auto matrix_dim = common::flatten_to_2d(x_dim, begin_norm_axis);
+  auto before_norm_dims = slice_ddim(x_dim, 0, begin_norm_axis);
+  int64_t right = matrix_dim[1];
 
   PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
                     true,
@@ -3881,13 +3913,16 @@ void RmsNormInferMeta(const MetaTensor& x,
                         "0.0 and 0.001, But received [%s].",
                         epsilon));
 
-  phi::DataType scale_dtype = scale.dtype();
+  DataType x_dtype = x.dtype();
   y->set_dims(x_dim);
-  y->set_dtype(scale_dtype);
-
-  auto row_shape = slice_ddim(x_dim, 0, x_dim.size() - 1);
-  invvar->set_dims({row_shape});
-  invvar->set_dtype(paddle::DataType::FLOAT32);
+  y->set_dtype(x_dtype);
+
+  DataType param_type =
+      (x_dtype == DataType::BFLOAT16 || x_dtype == DataType::FLOAT16)
+          ? DataType::FLOAT32
+          : x_dtype;
+  invvar->set_dims({before_norm_dims});
+  invvar->set_dtype(param_type);
 }
 
 void RowConvInferMeta(const MetaTensor& x,
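The tail of this hunk fixes the output metadata: y keeps x's shape and dtype, while invvar spans only the dims before begin_norm_axis and is promoted to float32 for half-precision inputs. A self-contained sketch of that rule (hypothetical DType/struct names, not phi's types):

#include <cstdint>
#include <vector>

enum class DType { FLOAT32, FLOAT16, BFLOAT16, FLOAT64 };

struct RmsNormInvvarMeta {
  std::vector<int64_t> dims;
  DType dtype;
};

// invvar covers the leading (un-normalized) dims; half-precision inputs
// accumulate statistics in float32, everything else keeps its own dtype.
RmsNormInvvarMeta InferInvvarMeta(const std::vector<int64_t>& x_dims,
                                  int begin_norm_axis,
                                  DType x_dtype) {
  RmsNormInvvarMeta out;
  out.dims.assign(x_dims.begin(), x_dims.begin() + begin_norm_axis);
  out.dtype = (x_dtype == DType::FLOAT16 || x_dtype == DType::BFLOAT16)
                  ? DType::FLOAT32
                  : x_dtype;
  return out;
}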

paddle/phi/infermeta/binary.h

Lines changed: 2 additions & 1 deletion
@@ -749,7 +749,8 @@ PADDLE_API void ReduceAsInferMeta(const MetaTensor& x,
 
 PADDLE_API void RmsNormInferMeta(const MetaTensor& x,
                                  const MetaTensor& scale,
-                                 float epsilon,
+                                 const std::vector<int64_t>& normalized_shape,
+                                 double epsilon,
                                  MetaTensor* y,
                                  MetaTensor* invvar);

paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu

Lines changed: 3 additions & 114 deletions
@@ -11,130 +11,18 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <cassert>
-#include <vector>
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/empty_kernel.h"  // NOLINT
 
+#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h"  // NOLINT
-
-namespace phi {
-
-static void GetRowsCols(const std::vector<int64_t> &shape,
-                        int *p_rows,
-                        int *p_cols) {
-  int rows = 1;
-  for (int i = 0; i + 1 < shape.size(); ++i) {
-    rows *= shape[i];
-  }
-  int cols = shape[shape.size() - 1];
-  *p_rows = rows;
-  *p_cols = cols;
-}
-
-template <typename T, typename Context>
-void RMSNormFwdKernel(const Context &dev_ctx,
-                      const DenseTensor &x,
-                      const DenseTensor &scale,
-                      float epsilon,
-                      DenseTensor *y,
-                      DenseTensor *invvar) {
-  const auto &scale_shape = scale.dims();
-  int rows, cols;
-  GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
-  if (scale.dtype() == phi::DataType::BFLOAT16) {
-    dev_ctx.template Alloc<phi::bfloat16>(y);
-  } else if (scale.dtype() == phi::DataType::FLOAT32) {
-    dev_ctx.template Alloc<float>(y);
-  } else {
-    PADDLE_THROW(common::errors::InvalidArgument(
-        "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
-        scale.dtype()));
-  }
-  invvar->Resize({rows});
-  dev_ctx.template Alloc<float>(invvar);
-  cuda_rms_norm<T, Context>(dev_ctx, x, scale, rows, cols, epsilon, y, invvar);
-}
-
-template <typename T, typename Context>
-void RMSNormBwdKernel(const Context &dev_ctx,
-                      const DenseTensor &x,
-                      const DenseTensor &scale,
-                      const DenseTensor &invvar,
-                      const DenseTensor &y_grad,
-                      float epsilon,
-                      DenseTensor *x_grad,
-                      DenseTensor *scale_grad) {
-  int rows, cols;
-  GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
-  dev_ctx.template Alloc<T>(x_grad);
-  if (scale_grad) {
-    if (scale.dtype() == phi::DataType::BFLOAT16) {
-      dev_ctx.template Alloc<phi::bfloat16>(scale_grad);
-    } else if (scale.dtype() == phi::DataType::FLOAT32) {
-      dev_ctx.template Alloc<float>(scale_grad);
-    } else {
-      PADDLE_THROW(common::errors::InvalidArgument(
-          "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
-          scale.dtype()));
-    }
-    cuda_rms_norm_gradient<T, Context>(dev_ctx,
-                                       x,
-                                       scale,
-                                       invvar,
-                                       y_grad,
-                                       rows,
-                                       cols,
-                                       epsilon,
-                                       x_grad,
-                                       scale_grad);
-  } else {
-    // lora specific
-    if (scale.dtype() == phi::DataType::BFLOAT16) {
-      DenseTensor scale_grad_tmp =
-          phi::EmptyLike<phi::bfloat16, Context>(dev_ctx, scale);
-      cuda_rms_norm_gradient<T, Context>(dev_ctx,
-                                         x,
-                                         scale,
-                                         invvar,
-                                         y_grad,
-                                         rows,
-                                         cols,
-                                         epsilon,
-                                         x_grad,
-                                         &scale_grad_tmp);
-    } else if (scale.dtype() == phi::DataType::FLOAT32) {
-      DenseTensor scale_grad_tmp =
-          phi::EmptyLike<float, Context>(dev_ctx, scale);
-      cuda_rms_norm_gradient<T, Context>(dev_ctx,
-                                         x,
-                                         scale,
-                                         invvar,
-                                         y_grad,
-                                         rows,
-                                         cols,
-                                         epsilon,
-                                         x_grad,
-                                         &scale_grad_tmp);
-    } else {
-      PADDLE_THROW(common::errors::InvalidArgument(
-          "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
-          scale.dtype()));
-    }
-  }
-}
-
-}  // namespace phi
 
 PD_REGISTER_KERNEL(rms_norm,
                    GPU,
                    ALL_LAYOUT,
                    phi::RMSNormFwdKernel,
                    float,
                    double,
+                   phi::float16,
                    phi::bfloat16) {}
 
 PD_REGISTER_KERNEL(rms_norm_grad,
@@ -143,4 +31,5 @@ PD_REGISTER_KERNEL(rms_norm_grad,
                    phi::RMSNormBwdKernel,
                    float,
                    double,
+                   phi::float16,
                    phi::bfloat16) {}
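The deleted GetRowsCols above hard-coded the rows/cols split before the last dim. With normalized_shape in play, the equivalent flattening splits at begin_norm_axis; a sketch of that generalization (the commit's replacement lives in rms_norm_cuda_kernel.h, which this diff does not show, so this is only an assumption about the shape math):

#include <cstdint>
#include <vector>

// Flatten a shape into (rows, cols) at begin_norm_axis: rows is the
// product of the leading dims, cols the product of the normalized dims.
static void GetRowsCols(const std::vector<int64_t>& shape,
                        int begin_norm_axis,
                        int64_t* p_rows,
                        int64_t* p_cols) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < begin_norm_axis; ++i) rows *= shape[i];
  for (int i = begin_norm_axis; i < static_cast<int>(shape.size()); ++i)
    cols *= shape[i];
  *p_rows = rows;
  *p_cols = cols;
}
// e.g. shape = {8, 16, 512}, begin_norm_axis = 2 -> rows = 128, cols = 512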
