@@ -3427,20 +3427,34 @@ bool RmsNormOpInferSymbolicShape(
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &scale_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
std::vector<int64_t> normalized_shape =
paddle::dialect::details::GetVectorAttr<int64_t>(op, "normalized_shape");

std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape();
// NOTE(large-tensor): tensor indices are small integers
int begin_norm_axis = static_cast<int>(x_dims.size() - 1);
int x_dims_size = static_cast<int>(x_dims.size());
int normalized_shape_size = static_cast<int>(normalized_shape.size());
int begin_norm_axis = x_dims_size - normalized_shape_size;

// Flatten x_dims to 2D and get dim[1]
symbol::DimExpr matrix_dim_1 = x_dims[begin_norm_axis];
for (std::size_t i = begin_norm_axis + 1; i < x_dims.size(); ++i) {
matrix_dim_1 = matrix_dim_1 * x_dims[i];
PADDLE_ENFORCE_LT(normalized_shape_size,
x_dims_size,
common::errors::InvalidArgument(
"The size of normalized_shape must be less than the rank of X, "
"but received normalized_shape size [%d] and X rank [%d].",
normalized_shape_size,
x_dims_size));
for (int i = 0; i < normalized_shape_size; i++) {
infer_context->AddEqualCstr(
x_dims[x_dims_size - i - 1],
symbol::DimExpr(normalized_shape[normalized_shape_size - i - 1]));
}

if (!scale_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
std::vector<symbol::DimExpr> scale_dims = scale_shape_or_data.shape();
infer_context->AddEqualCstr(scale_dims[0], matrix_dim_1);
PADDLE_ENFORCE_EQ(
static_cast<int>(scale_dims.size()),
normalized_shape_size,
common::errors::InvalidArgument(
"The rank of Scale must be equal to the size of normalized_shape, "
"but received Scale rank [%d] and normalized_shape size [%d].",
static_cast<int>(scale_dims.size()),
normalized_shape_size));
for (int i = 0; i < normalized_shape_size; i++) {
infer_context->AddEqualCstr(scale_dims[i],
symbol::DimExpr(normalized_shape[i]));
}
}

// Set output shapes
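For context, the constraint logic above can be exercised with a small standalone sketch that uses plain integer dims instead of symbol::DimExpr. The shapes below are made-up illustrations, not values from this change:

// Standalone sketch (illustration only): the trailing dims of x must match
// normalized_shape, aligned from the right, and everything before them sits
// in front of begin_norm_axis.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> x_dims = {8, 16, 32, 64};     // hypothetical input shape
  std::vector<int64_t> normalized_shape = {32, 64};  // normalize over the last two dims

  int x_dims_size = static_cast<int>(x_dims.size());
  int normalized_shape_size = static_cast<int>(normalized_shape.size());
  int begin_norm_axis = x_dims_size - normalized_shape_size;  // == 2

  assert(normalized_shape_size < x_dims_size);
  for (int i = 0; i < normalized_shape_size; ++i) {
    // Mirrors the AddEqualCstr calls above, but on concrete integers.
    assert(x_dims[x_dims_size - i - 1] ==
           normalized_shape[normalized_shape_size - i - 1]);
  }
  return begin_norm_axis == 2 ? 0 : 1;
}
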
31 changes: 9 additions & 22 deletions paddle/phi/infermeta/backward.cc
@@ -1671,28 +1671,15 @@ void FusedRmsNormQuantGradInferMeta(const MetaTensor& x,
}
}

PADDLE_API void RMSNormGradInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& y_grad,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad) {
PADDLE_ENFORCE_EQ(
x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
x.dtype() == DataType::BFLOAT16,
true,
common::errors::InvalidArgument(
"The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
x.dtype()));
PADDLE_ENFORCE_EQ(
scale.dtype() == DataType::FLOAT32 ||
scale.dtype() == DataType::FLOAT16 ||
scale.dtype() == DataType::BFLOAT16,
true,
common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
"FLOAT16 or BFLOAT16, but got [%s]",
scale.dtype()));
PADDLE_API void RMSNormGradInferMeta(
const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& y_grad,
const std::vector<int64_t>& normalized_shape,
double epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad) {
if (x_grad && x) {
x_grad->share_meta(x);
}
16 changes: 9 additions & 7 deletions paddle/phi/infermeta/backward.h
@@ -629,13 +629,15 @@ PADDLE_API void FusedRmsNormQuantGradInferMeta(const MetaTensor& x,
MetaTensor* norm_weight_grad,
MetaTensor* norm_bias_grad);

PADDLE_API void RMSNormGradInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& y_grad,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad);
PADDLE_API void RMSNormGradInferMeta(
const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& y_grad,
const std::vector<int64_t>& normalized_shape,
double epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad);

PADDLE_API void RnnGradInferMeta(
const MetaTensor& x,
89 changes: 62 additions & 27 deletions paddle/phi/infermeta/binary.cc
@@ -3885,36 +3885,68 @@ void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x,

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
float epsilon,
const std::vector<int64_t>& normalized_shape,
double epsilon,
MetaTensor* y,
MetaTensor* invvar) {
auto x_dim = x.dims();
auto x_ndim = x_dim.size();
int normalized_shape_size = static_cast<int>(normalized_shape.size());
int x_dims_size = x_dim.size();
int begin_norm_axis = x_dims_size - normalized_shape_size;

auto matrix_dim = common::flatten_to_2d(x_dim, x_ndim - 1);
PADDLE_ENFORCE_GT(begin_norm_axis,
0,
common::errors::InvalidArgument(
"'begin_norm_axis' in Op(LayerNorm) should be "
"greater than zero. But received [%d].",
begin_norm_axis));

PADDLE_ENFORCE_LT(
begin_norm_axis,
x_dims_size,
common::errors::InvalidArgument(
"'begin_norm_axis' must be less than the dimensions of X,"
"But received 'begin_norm_axis' is [%d],"
"received the dimensions of X is [%d].",
begin_norm_axis,
x_dims_size));

for (int i = 0; i < normalized_shape_size; i++) {
PADDLE_ENFORCE_EQ(x_dim[x_dims_size - i - 1],
normalized_shape[normalized_shape_size - i - 1],
common::errors::InvalidArgument(
"The %d-th dimension of X is not equal to the %d-th "
"dimension of NormalizedShape.",
x_dims_size - i - 1,
normalized_shape_size - i - 1));
}

int64_t right = matrix_dim[1];
if (scale) {
PADDLE_ENFORCE_EQ(scale.dims().size(),
1,
auto scale_dim = scale.dims();
PADDLE_ENFORCE_EQ(scale_dim.size(),
normalized_shape_size,
common::errors::InvalidArgument(
"The dimensions of Input(Scale) must be 1, but "
"received dimensions of "
"Input(Scale) is [%d]",
scale.dims().size()));
"The dimensions of Input(Scale) must be equal to the "
"dimensions of NormalizedShape. "
"But received: the dimensions of Input(Scale) is "
"[%d], the dimensions of NormalizedShape is [%d].",
scale_dim.size(),
normalized_shape_size));
for (int i = 0; i < normalized_shape_size; i++) {
PADDLE_ENFORCE_EQ(scale_dim[i],
normalized_shape[i],
common::errors::InvalidArgument(
"The %d-th dimension of Input(Scale) is not equal "
"to the %d-th dimension of NormalizedShape.",
i,
i));
}
}

PADDLE_ENFORCE_EQ(
scale.dims()[0],
right,
common::errors::InvalidArgument(
"The first dimension value of Input(Scale) must equal to be the "
"second dimension value of the flattened 2D matrix of Input(X), "
"But received the first dimension value of Input(Scale) is "
"[%d], the second dimension value of the flattened 2D matrix of "
" Input(Scale) is [%d].",
scale.dims()[0],
right));
auto matrix_dim = common::flatten_to_2d(x_dim, begin_norm_axis);
auto before_norm_dims = slice_ddim(x_dim, 0, begin_norm_axis);
int64_t right = matrix_dim[1];

PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
true,
@@ -3923,13 +3955,16 @@ void RmsNormInferMeta(const MetaTensor& x,
"0.0 and 0.001, But received [%s].",
epsilon));

phi::DataType scale_dtype = scale.dtype();
DataType x_dtype = x.dtype();
y->set_dims(x_dim);
y->set_dtype(scale_dtype);

auto row_shape = slice_ddim(x_dim, 0, x_dim.size() - 1);
invvar->set_dims({row_shape});
invvar->set_dtype(paddle::DataType::FLOAT32);
y->set_dtype(x_dtype);

DataType param_type =
(x_dtype == DataType::BFLOAT16 || x_dtype == DataType::FLOAT16)
? DataType::FLOAT32
: x_dtype;
invvar->set_dims({before_norm_dims});
invvar->set_dtype(param_type);
}
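
As a quick illustration of the resulting output metadata (hypothetical shapes and dtypes, not taken from this change): y keeps x's dims and dtype, while invvar keeps only the leading, un-normalized dims and is promoted to float32 for fp16/bf16 inputs.

// Sketch of the output-meta rules above with concrete values (illustration only).
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<int64_t> x_dims = {8, 16, 32, 64};
  std::vector<int64_t> normalized_shape = {32, 64};
  std::string x_dtype = "float16";

  int begin_norm_axis = static_cast<int>(x_dims.size()) -
                        static_cast<int>(normalized_shape.size());  // == 2

  std::vector<int64_t> y_dims = x_dims;  // y: same shape and dtype as x
  std::vector<int64_t> invvar_dims(x_dims.begin(),
                                   x_dims.begin() + begin_norm_axis);
  std::string invvar_dtype =
      (x_dtype == "float16" || x_dtype == "bfloat16") ? "float32" : x_dtype;

  // Prints: invvar dims = [8, 16], invvar dtype = float32
  std::cout << "invvar dims = [" << invvar_dims[0] << ", " << invvar_dims[1]
            << "], invvar dtype = " << invvar_dtype << std::endl;
  return y_dims.size() == 4 ? 0 : 1;
}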

void RowConvInferMeta(const MetaTensor& x,
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/binary.h
@@ -775,7 +775,8 @@ PADDLE_API void ReduceAsInferMeta(const MetaTensor& x,

PADDLE_API void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
float epsilon,
const std::vector<int64_t>& normalized_shape,
double epsilon,
MetaTensor* y,
MetaTensor* invvar);

116 changes: 3 additions & 113 deletions paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu
@@ -11,129 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h" // NOLINT

#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h" // NOLINT

namespace phi {

static void GetRowsCols(const std::vector<int64_t> &shape,
int *p_rows,
int *p_cols) {
int rows = 1;
for (int i = 0; i + 1 < shape.size(); ++i) {
rows *= shape[i];
}
int cols = shape[shape.size() - 1];
*p_rows = rows;
*p_cols = cols;
}

template <typename T, typename Context>
void RMSNormFwdKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &scale,
float epsilon,
DenseTensor *y,
DenseTensor *invvar) {
const auto &scale_shape = scale.dims();
int rows, cols;
GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
if (scale.dtype() == phi::DataType::BFLOAT16) {
dev_ctx.template Alloc<phi::bfloat16>(y);
} else if (scale.dtype() == phi::DataType::FLOAT32) {
dev_ctx.template Alloc<float>(y);
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
scale.dtype()));
}
invvar->Resize({rows});
dev_ctx.template Alloc<float>(invvar);
cuda_rms_norm<T, Context>(dev_ctx, x, scale, rows, cols, epsilon, y, invvar);
}

template <typename T, typename Context>
void RMSNormBwdKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &scale,
const DenseTensor &invvar,
const DenseTensor &y_grad,
float epsilon,
DenseTensor *x_grad,
DenseTensor *scale_grad) {
int rows, cols;
GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
dev_ctx.template Alloc<T>(x_grad);
if (scale_grad) {
if (scale.dtype() == phi::DataType::BFLOAT16) {
dev_ctx.template Alloc<phi::bfloat16>(scale_grad);
} else if (scale.dtype() == phi::DataType::FLOAT32) {
dev_ctx.template Alloc<float>(scale_grad);
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
scale.dtype()));
}
cuda_rms_norm_gradient<T, Context>(dev_ctx,
x,
scale,
invvar,
y_grad,
rows,
cols,
epsilon,
x_grad,
scale_grad);
} else {
// lora specific
if (scale.dtype() == phi::DataType::BFLOAT16) {
DenseTensor scale_grad_tmp =
EmptyLike<phi::bfloat16, Context>(dev_ctx, scale);
cuda_rms_norm_gradient<T, Context>(dev_ctx,
x,
scale,
invvar,
y_grad,
rows,
cols,
epsilon,
x_grad,
&scale_grad_tmp);
} else if (scale.dtype() == phi::DataType::FLOAT32) {
DenseTensor scale_grad_tmp = EmptyLike<float, Context>(dev_ctx, scale);
cuda_rms_norm_gradient<T, Context>(dev_ctx,
x,
scale,
invvar,
y_grad,
rows,
cols,
epsilon,
x_grad,
&scale_grad_tmp);
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
scale.dtype()));
}
}
}

} // namespace phi

PD_REGISTER_KERNEL(rms_norm,
GPU,
ALL_LAYOUT,
phi::RMSNormFwdKernel,
float,
double,
phi::float16,
phi::bfloat16) {}

PD_REGISTER_KERNEL(rms_norm_grad,
@@ -142,4 +31,5 @@ PD_REGISTER_KERNEL(rms_norm_grad,
phi::RMSNormBwdKernel,
float,
double,
phi::float16,
phi::bfloat16) {}
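
The deleted GetRowsCols helper always split the input at the last axis; with normalized_shape, the split point becomes begin_norm_axis, as in the flatten_to_2d call in RmsNormInferMeta above. Below is a hedged sketch of that generalized flattening, for illustration only — the GetRowsColsAt name is made up for this sketch and it is not the new kernel code, which is not shown in this diff.

// Illustration: flatten a shape to [rows, cols] at begin_norm_axis, generalizing
// the removed GetRowsCols (which always split at the last axis). Not Paddle API.
#include <cstdint>
#include <vector>

static void GetRowsColsAt(const std::vector<int64_t>& shape,
                          int begin_norm_axis,
                          int64_t* p_rows,
                          int64_t* p_cols) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < begin_norm_axis; ++i) rows *= shape[i];
  for (int i = begin_norm_axis; i < static_cast<int>(shape.size()); ++i)
    cols *= shape[i];
  *p_rows = rows;  // product of the leading (batch) dims
  *p_cols = cols;  // product of normalized_shape
}

int main() {
  int64_t rows = 0, cols = 0;
  GetRowsColsAt({8, 16, 32, 64}, /*begin_norm_axis=*/2, &rows, &cols);
  return (rows == 8 * 16 && cols == 32 * 64) ? 0 : 1;
}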