diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 233a8de05dbacb..09bce252716bf3 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -3427,20 +3427,34 @@ bool RmsNormOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &scale_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(1)); + std::vector normalized_shape = + paddle::dialect::details::GetVectorAttr(op, "normalized_shape"); std::vector x_dims = x_shape_or_data.shape(); - // NOTE(large-tensor): tensor indices are small integers - int begin_norm_axis = static_cast(x_dims.size() - 1); + int x_dims_size = x_dims.size(); + int normalized_shape_size = normalized_shape.size(); + int begin_norm_axis = x_dims_size - normalized_shape_size; // Flatten x_dims to 2D and get dim[1] - symbol::DimExpr matrix_dim_1 = x_dims[begin_norm_axis]; - for (std::size_t i = begin_norm_axis + 1; i < x_dims.size(); ++i) { - matrix_dim_1 = matrix_dim_1 * x_dims[i]; + PADDLE_ENFORCE_LT(normalized_shape_size, + x_dims_size, + "normalized_shape must be less than x_dims"); + for (int i = 0; i < normalized_shape_size; i++) { + infer_context->AddEqualCstr( + x_dims[x_dims_size - i - 1], + symbol::DimExpr(normalized_shape[normalized_shape_size - i - 1])); } if (!scale_shape_or_data.isa()) { std::vector scale_dims = scale_shape_or_data.shape(); - infer_context->AddEqualCstr(scale_dims[0], matrix_dim_1); + PADDLE_ENFORCE_EQ( + scale_dims.size(), + normalized_shape_size, + "scale_dims.size() must be equal to normalized_shape_size"); + for (int i = 0; i < normalized_shape_size; i++) { + infer_context->AddEqualCstr(scale_dims[i], + symbol::DimExpr(normalized_shape[i])); + } } // Set output shapes diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 63a168c3474c95..62838603f8db4e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1671,28 +1671,15 @@ void FusedRmsNormQuantGradInferMeta(const MetaTensor& x, } } -PADDLE_API void RMSNormGradInferMeta(const MetaTensor& x, - const MetaTensor& scale, - const MetaTensor& invvar, - const MetaTensor& y_grad, - float epsilon, - MetaTensor* x_grad, - MetaTensor* scale_grad) { - PADDLE_ENFORCE_EQ( - x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 || - x.dtype() == DataType::BFLOAT16, - true, - common::errors::InvalidArgument( - "The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]", - x.dtype())); - PADDLE_ENFORCE_EQ( - scale.dtype() == DataType::FLOAT32 || - scale.dtype() == DataType::FLOAT16 || - scale.dtype() == DataType::BFLOAT16, - true, - common::errors::InvalidArgument("The dtype of scale must be FLOAT32, " - "FLOAT16 or BFLOAT16, but got [%s]", - scale.dtype())); +PADDLE_API void RMSNormGradInferMeta( + const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& y_grad, + const std::vector& normalized_shape, + double epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad) { if (x_grad && x) { x_grad->share_meta(x); } diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 92336f9380626b..7e51094f498b4f 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -629,13 +629,15 @@ PADDLE_API void 
FusedRmsNormQuantGradInferMeta(const MetaTensor& x, MetaTensor* norm_weight_grad, MetaTensor* norm_bias_grad); -PADDLE_API void RMSNormGradInferMeta(const MetaTensor& x, - const MetaTensor& scale, - const MetaTensor& invvar, - const MetaTensor& y_grad, - float epsilon, - MetaTensor* x_grad, - MetaTensor* scale_grad); +PADDLE_API void RMSNormGradInferMeta( + const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& y_grad, + const std::vector& normalized_shape, + double epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad); PADDLE_API void RnnGradInferMeta( const MetaTensor& x, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 69081c339bee43..850a86630640bb 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -3885,36 +3885,68 @@ void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x, void RmsNormInferMeta(const MetaTensor& x, const MetaTensor& scale, - float epsilon, + const std::vector& normalized_shape, + double epsilon, MetaTensor* y, MetaTensor* invvar) { auto x_dim = x.dims(); - auto x_ndim = x_dim.size(); + // std::vector normalized_shape_data = normalized_shape.GetData(); + int normalized_shape_size = normalized_shape.size(); + int x_dims_size = x_dim.size(); + int begin_norm_axis = x_dims_size - normalized_shape_size; - auto matrix_dim = common::flatten_to_2d(x_dim, x_ndim - 1); + PADDLE_ENFORCE_GT(begin_norm_axis, + 0, + common::errors::InvalidArgument( + "'begin_norm_axis' in Op(LayerNorm) should be " + "greater than zero. But received [%d].", + begin_norm_axis)); + + PADDLE_ENFORCE_LT( + begin_norm_axis, + x_dims_size, + common::errors::InvalidArgument( + "'begin_norm_axis' must be less than the dimensions of X," + "But received 'begin_norm_axis' is [%d]," + "received the dimensions of X is [%d].", + begin_norm_axis, + x_dims_size)); + + for (int i = 0; i < normalized_shape_size; i++) { + PADDLE_ENFORCE_EQ(x_dim[x_dims_size - i - 1], + normalized_shape[normalized_shape_size - i - 1], + common::errors::InvalidArgument( + "The %d-th dimension of X is not equal to the %d-th " + "dimension of NormalizedShape.", + x_dims_size - i - 1, + normalized_shape_size - i - 1)); + } - int64_t right = matrix_dim[1]; if (scale) { - PADDLE_ENFORCE_EQ(scale.dims().size(), - 1, + auto scale_dim = scale.dims(); + PADDLE_ENFORCE_EQ(scale_dim.size(), + normalized_shape_size, common::errors::InvalidArgument( - "The dimensions of Input(Scale) must be 1, but " - "received dimensions of " - "Input(Scale) is [%d]", - scale.dims().size())); + "The dimensions of Input(Scale) must be equal to the " + "dimensions of NormalizedShape. 
" + "But received: the dimensions of Input(Scale) is " + "[%d], the dimensions of NormalizedShape is [%d].", + scale_dim.size(), + normalized_shape_size)); + for (int i = 0; i < normalized_shape_size; i++) { + PADDLE_ENFORCE_EQ(scale_dim[i], + normalized_shape[i], + common::errors::InvalidArgument( + "The %d-th dimension of Input(Scale) is not equal " + "to the %d-th dimension of NormalizedShape.", + i, + i)); + } } - PADDLE_ENFORCE_EQ( - scale.dims()[0], - right, - common::errors::InvalidArgument( - "The first dimension value of Input(Scale) must equal to be the " - "second dimension value of the flattened 2D matrix of Input(X), " - "But received the first dimension value of Input(Scale) is " - "[%d], the second dimension value of the flattened 2D matrix of " - " Input(Scale) is [%d].", - scale.dims()[0], - right)); + auto matrix_dim = common::flatten_to_2d(x_dim, begin_norm_axis); + auto before_norm_dims = slice_ddim(x_dim, 0, begin_norm_axis); + int64_t right = matrix_dim[1]; PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true, @@ -3923,13 +3955,16 @@ void RmsNormInferMeta(const MetaTensor& x, "0.0 and 0.001, But received [%s].", epsilon)); - phi::DataType scale_dtype = scale.dtype(); + DataType x_dtype = x.dtype(); y->set_dims(x_dim); - y->set_dtype(scale_dtype); - - auto row_shape = slice_ddim(x_dim, 0, x_dim.size() - 1); - invvar->set_dims({row_shape}); - invvar->set_dtype(paddle::DataType::FLOAT32); + y->set_dtype(x_dtype); + + DataType param_type = + (x_dtype == DataType::BFLOAT16 || x_dtype == DataType::FLOAT16) + ? DataType::FLOAT32 + : x_dtype; + invvar->set_dims({before_norm_dims}); + invvar->set_dtype(param_type); } void RowConvInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index ae87d84b9a94e7..f1e9464094243a 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -775,7 +775,8 @@ PADDLE_API void ReduceAsInferMeta(const MetaTensor& x, PADDLE_API void RmsNormInferMeta(const MetaTensor& x, const MetaTensor& scale, - float epsilon, + const std::vector& normalized_shape, + double epsilon, MetaTensor* y, MetaTensor* invvar); diff --git a/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu index b524c92d7be7b3..82bfa0985a7e20 100644 --- a/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu +++ b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu @@ -11,122 +11,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include -#include -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/empty_kernel.h" // NOLINT +#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h" // NOLINT - -namespace phi { - -static void GetRowsCols(const std::vector &shape, - int *p_rows, - int *p_cols) { - int rows = 1; - for (int i = 0; i + 1 < shape.size(); ++i) { - rows *= shape[i]; - } - int cols = shape[shape.size() - 1]; - *p_rows = rows; - *p_cols = cols; -} - -template -void RMSNormFwdKernel(const Context &dev_ctx, - const DenseTensor &x, - const DenseTensor &scale, - float epsilon, - DenseTensor *y, - DenseTensor *invvar) { - const auto &scale_shape = scale.dims(); - int rows, cols; - GetRowsCols(common::vectorize(x.dims()), &rows, &cols); - if (scale.dtype() == phi::DataType::BFLOAT16) { - dev_ctx.template Alloc(y); - } else if (scale.dtype() == phi::DataType::FLOAT32) { - dev_ctx.template Alloc(y); - } else { - PADDLE_THROW(common::errors::InvalidArgument( - "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]", - scale.dtype())); - } - invvar->Resize({rows}); - dev_ctx.template Alloc(invvar); - cuda_rms_norm(dev_ctx, x, scale, rows, cols, epsilon, y, invvar); -} - -template -void RMSNormBwdKernel(const Context &dev_ctx, - const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &invvar, - const DenseTensor &y_grad, - float epsilon, - DenseTensor *x_grad, - DenseTensor *scale_grad) { - int rows, cols; - GetRowsCols(common::vectorize(x.dims()), &rows, &cols); - dev_ctx.template Alloc(x_grad); - if (scale_grad) { - if (scale.dtype() == phi::DataType::BFLOAT16) { - dev_ctx.template Alloc(scale_grad); - } else if (scale.dtype() == phi::DataType::FLOAT32) { - dev_ctx.template Alloc(scale_grad); - } else { - PADDLE_THROW(common::errors::InvalidArgument( - "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]", - scale.dtype())); - } - cuda_rms_norm_gradient(dev_ctx, - x, - scale, - invvar, - y_grad, - rows, - cols, - epsilon, - x_grad, - scale_grad); - } else { - // lora specific - if (scale.dtype() == phi::DataType::BFLOAT16) { - DenseTensor scale_grad_tmp = - EmptyLike(dev_ctx, scale); - cuda_rms_norm_gradient(dev_ctx, - x, - scale, - invvar, - y_grad, - rows, - cols, - epsilon, - x_grad, - &scale_grad_tmp); - } else if (scale.dtype() == phi::DataType::FLOAT32) { - DenseTensor scale_grad_tmp = EmptyLike(dev_ctx, scale); - cuda_rms_norm_gradient(dev_ctx, - x, - scale, - invvar, - y_grad, - rows, - cols, - epsilon, - x_grad, - &scale_grad_tmp); - } else { - PADDLE_THROW(common::errors::InvalidArgument( - "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]", - scale.dtype())); - } - } -} - -} // namespace phi PD_REGISTER_KERNEL(rms_norm, GPU, @@ -134,6 +22,7 @@ PD_REGISTER_KERNEL(rms_norm, phi::RMSNormFwdKernel, float, double, + phi::float16, phi::bfloat16) {} PD_REGISTER_KERNEL(rms_norm_grad, @@ -142,4 +31,5 @@ PD_REGISTER_KERNEL(rms_norm_grad, phi::RMSNormBwdKernel, float, double, + phi::float16, phi::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h index 6b7cade23cb62a..1584ff29e54c6f 100644 --- a/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h +++ b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h @@ -14,1025 +14,1098 @@ #pragma once -#include "paddle/common/exception.h" +#include + +#include 
"paddle/common/ddim.h" +#include "paddle/common/flags.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" -#include // NOLINT -#include // NOLINT +COMMON_DECLARE_bool(use_accuracy_compatible_kernel); namespace phi { -#define DEFAULT_THROW(NAME, TYPE) \ - default: \ - do { \ - PD_THROW(#NAME, " not implemented for '", TYPE, "'"); \ - } while (0); \ - break - -#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ - switch (TYPEIN) { \ - case float: { \ - using scalar_t_in = float; \ - switch (TYPEOUT) { \ - case float: { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - DEFAULT_THROW(NAME, TYPEOUT); \ - } \ - break; \ - } \ - DEFAULT_THROW(NAME, TYPEIN); \ - } -#define WARP_SIZE 32 +// ----------------------------------------------------------------------- +// Constants +// ----------------------------------------------------------------------- + +static constexpr int kCUDANumThreads = 256; +static constexpr int kCUDABlockReduceNumThreads = 512; +static constexpr int kWarpSize = 32; + +// ----------------------------------------------------------------------- +// Helper Functions & Structs +// ----------------------------------------------------------------------- template -__device__ __forceinline__ T WARP_SHFL_XOR(T value, - int laneMask, - int width = WARP_SIZE, - unsigned int mask = 0xffffffff) { - return __shfl_xor_sync(mask, value, laneMask, width); +__device__ __forceinline__ T Rsqrt_(T x); + +template <> +__device__ __forceinline__ float Rsqrt_(float x) { + return rsqrtf(x); } -template -__device__ __forceinline__ T WARP_SHFL(T value, - int srcLane, - int width = WARP_SIZE, - unsigned int mask = 0xffffffff) { - return __shfl_sync(mask, value, srcLane, width); +template <> +__device__ __forceinline__ double Rsqrt_(double x) { + return rsqrt(x); } -template -__device__ void cuWelfordOnlineSum_(const U curr, - U& mu, // NOLINT - U& sigma2, // NOLINT - U& count) { // NOLINT - count = count + U(1); - U delta = curr - mu; - U lmean = mu + delta / count; - mu = lmean; - U delta2 = curr - lmean; - sigma2 = sigma2 + delta * delta2; +template +struct alignas(sizeof(T) * kVecSize) aligned_vector { + T val[kVecSize]; +}; + +template +struct SimplePair { + T1 first; + T2 second; + + __host__ __device__ SimplePair() {} + __host__ __device__ SimplePair(T1 f, T2 s) : first(f), second(s) {} +}; + +template +bool can_vectorize(const T* ptr, int alignment) { + uint64_t addr = reinterpret_cast(ptr); + return addr % alignment == 0; } -template -__device__ void cuChanOnlineSum_(const U muB, - const U sigma2B, - const U countB, - U& mu, // NOLINT - U& sigma2, // NOLINT - U& count) { // NOLINT - U delta = muB - mu; - U nA = count; - U nB = countB; - count = count + countB; - U nX = count; - if (nX > U(0)) { - nA = nA / nX; - nB = nB / nX; - mu = nA * mu + nB * muB; - sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; - } else { - mu = U(0); - sigma2 = U(0); - } +// ----------------------------------------------------------------------- +// Welford Algorithms +// ----------------------------------------------------------------------- +template +struct WelfordData { + scalar_t mean; + scalar_t m2; + index_t n; + scalar_t nf; + + __host__ __device__ WelfordData() : mean(0), m2(0), n(0), nf(0) {} + + __host__ __device__ + 
WelfordData(scalar_t mean, scalar_t m2, index_t n, scalar_t nf) + : mean(mean), m2(m2), n(n), nf(nf) {} +}; + +// ----------------------------------------------------------------------- +// Warp & Block Reductions +// ----------------------------------------------------------------------- + +template +__device__ __forceinline__ T WARP_SHFL_DOWN_(T value, + int delta, + int width = kWarpSize, + unsigned int mask = 0xffffffff) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif } -template -__device__ void cuRMSOnlineSum_(const U curr, U& sigma2) { // NOLINT - sigma2 = sigma2 + curr * curr; +template +__device__ __forceinline__ T WARP_SHFL_(T value, + int srcLane, + int width = kWarpSize, + unsigned int mask = 0xffffffff) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_sync(mask, value, srcLane, width); +#else + return __shfl(value, srcLane, width); +#endif } -template -__device__ void cuChanRMSOnlineSum_(const U sigma2B, U& sigma2) { // NOLINT - sigma2 = sigma2 + sigma2B; +template +__device__ __forceinline__ T WARP_SHFL_XOR_(T value, + int laneMask, + int width = kWarpSize, + unsigned int mask = 0xffffffff) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif } -template -__device__ void cuWelfordMuSigma2_(const T* __restrict__ vals, - const int n1, - const int n2, - const int i1, - U& mu, // NOLINT - U& sigma2, // NOLINT - U* buf, - bool rms_only) { - // Assumptions: - // 1) blockDim.x == WARP_SIZE - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - U count = U(0); - mu = U(0); - sigma2 = U(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const T* lvals = vals + i1 * n2; - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - U curr = static_cast(lvals[l + k]); - if (!rms_only) { - cuWelfordOnlineSum_(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum_(curr, sigma2); - } - } - } - for (; l < n2; ++l) { - U curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum_(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum_(curr, sigma2); - } - } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - U sigma2B = WARP_SHFL(sigma2, srcLaneB); - if (!rms_only) { - U muB = WARP_SHFL(mu, srcLaneB); - U countB = WARP_SHFL(count, srcLaneB); - cuChanOnlineSum_(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum_(sigma2B, sigma2); - } - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - U* ubuf = (U*)buf; // NOLINT - U* ibuf = (U*)(ubuf + blockDim.y); // NOLINT - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - if (!rms_only) { - ubuf[2 * wrt_y] = mu; - ibuf[wrt_y] = count; - } - ubuf[2 * wrt_y + 1] = sigma2; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - U sigma2B = ubuf[2 * threadIdx.y + 1]; - if (!rms_only) { - U 
muB = ubuf[2 * threadIdx.y]; - U countB = ibuf[threadIdx.y]; - cuChanOnlineSum_(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum_(sigma2B, sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1] / U(n2); - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = WARP_SHFL(mu, 0); - } - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2 / U(n2), 0); +template +__device__ T BlockReduceSum(T val, T* shared) { + int lane = threadIdx.x % kWarpSize; + int wid = threadIdx.x / kWarpSize; + + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN_(val, offset); + } + + if (lane == 0) { + shared[wid] = val; + } + __syncthreads(); + + // Assuming blockDim.x <= 1024, max 32 warps + val = (threadIdx.x < blockDim.x / kWarpSize) ? shared[lane] : T(0); + + if (wid == 0) { + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN_(val, offset); } } + return val; } -template <> -__device__ void cuWelfordMuSigma2_(const phi::float16* __restrict__ vals, - const int n1, - const int n2, - const int i1, - float& mu, // NOLINT - float& sigma2, // NOLINT - float* buf, - bool rms_only) { - // Assumptions: - // 1) blockDim.x == WARP_SIZE - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - float count = 0.0f; - mu = float(0); // NOLINT - sigma2 = float(0); // NOLINT - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const auto* lvals = vals + i1 * n2; - int l = 8 * thrx; - if ((((size_t)lvals) & 3) != 0) { // NOLINT - // 16 bit alignment - // first thread consumes first point - if (thrx == 0) { - float curr = static_cast(lvals[0]); - if (!rms_only) { - cuWelfordOnlineSum_(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum_(curr, sigma2); - } - } - ++l; - } - // at this point, lvals[l] are 32 bit aligned for all threads. 
- for (; l + 7 < n2; l += 8 * numx) { - for (int k = 0; k < 8; k += 2) { - float2 curr = __half22float2(*((__half2*)(lvals + l + k))); // NOLINT - if (!rms_only) { - cuWelfordOnlineSum_(curr.x, mu, sigma2, count); - cuWelfordOnlineSum_(curr.y, mu, sigma2, count); - } else { - cuRMSOnlineSum_(curr.x, sigma2); - cuRMSOnlineSum_(curr.y, sigma2); - } - } - } - for (; l < n2; ++l) { - float curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum_(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum_(curr, sigma2); - } +template +struct WelfordOps { + acc_scalar_t correction; + bool take_sqrt; + + public: + using acc_t = WelfordData; + inline __device__ acc_t reduce(acc_t acc, + scalar_t data, + index_t /*idx*/) const { + index_t new_n = acc.n + 1; + acc_scalar_t new_nf = static_cast(new_n); + acc_scalar_t delta = data - acc.mean; + acc_scalar_t new_mean = acc.mean + delta / new_nf; + acc_scalar_t new_delta = data - new_mean; + return { + new_mean, + acc.m2 + delta * new_delta, + new_n, + new_nf, + }; + } + inline __device__ acc_t combine(acc_t a, acc_t b) const { + if (a.nf == 0) { + return b; } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - float sigma2B = WARP_SHFL(sigma2, srcLaneB); - if (!rms_only) { - float muB = WARP_SHFL(mu, srcLaneB); - float countB = WARP_SHFL(count, srcLaneB); - cuChanOnlineSum_(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum_(sigma2B, sigma2); - } + if (b.nf == 0) { + return a; } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float* ubuf = (float*)buf; // NOLINT - float* ibuf = (float*)(ubuf + blockDim.y); // NOLINT - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2 * wrt_y + 1] = sigma2; - if (!rms_only) { - ubuf[2 * wrt_y] = mu; - ibuf[wrt_y] = count; - } - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - float sigma2B = ubuf[2 * threadIdx.y + 1]; - if (!rms_only) { - float muB = ubuf[2 * threadIdx.y]; - float countB = ibuf[threadIdx.y]; - cuChanOnlineSum_(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum_(sigma2B, sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1] / float(n2); // NOLINT - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = WARP_SHFL(mu, 0); - } - sigma2 = WARP_SHFL(sigma2 / float(n2), 0); // NOLINT + acc_scalar_t delta = b.mean - a.mean; + acc_scalar_t new_count = a.nf + b.nf; + acc_scalar_t nb_over_n = b.nf / new_count; + return {a.mean + delta * nb_over_n, + a.m2 + b.m2 + delta * delta * a.nf * nb_over_n, + -1, + new_count}; + } + inline __device__ res_t project(acc_t acc) const { + const scalar_t mean = static_cast(acc.mean); + const acc_scalar_t divisor = acc.nf > correction ? acc.nf - correction : 0; + const acc_scalar_t var = acc.m2 / divisor; + res_t results(take_sqrt ? 
std::sqrt(var) : var, mean); + return results; + } + + static __device__ acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const { + return {WARP_SHFL_DOWN_(acc.mean, offset), + WARP_SHFL_DOWN_(acc.m2, offset), + WARP_SHFL_DOWN_(acc.n, offset), + WARP_SHFL_DOWN_(acc.nf, offset)}; + } +#endif + __host__ __device__ WelfordOps(acc_scalar_t correction, bool take_sqrt) + : correction(correction), take_sqrt(take_sqrt) {} +}; + +// ----------------------------------------------------------------------- +// Forward Kernels +// ----------------------------------------------------------------------- + +// Non-vectorized RowwiseMoments for RMSNorm +template +__global__ void RowwiseMomentsCUDAKernel(int64_t N, + T_ACC eps, + const T* X, + T_ACC* rstd) { + using WelfordType = WelfordData; + using WelfordOp = WelfordOps>; + + const int64_t i = blockIdx.x; + WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; + WelfordType val(0, 0, 0, 0); + + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + val = welford_op.reduce(val, static_cast(X[index]), index); + } + + // Block Reduce + // 1. Warp Reduce + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + WelfordType wdB = welford_op.warp_shfl_down(val, offset); + val = welford_op.combine(val, wdB); + } + + // 2. Block Reduce (via shared memory) + __shared__ + typename std::aligned_storage::type val_shared[32]; + WelfordType* val_shared_ptr = reinterpret_cast(val_shared); + + int lane = threadIdx.x % kWarpSize; + int wid = threadIdx.x / kWarpSize; + + __syncthreads(); + if (lane == 0) { + val_shared_ptr[wid] = val; + } + __syncthreads(); + + val = (threadIdx.x < blockDim.x / kWarpSize) ? val_shared_ptr[lane] + : WelfordType(0, 0, 0, 0); + + // Final Warp Reduce for the first warp + if (wid == 0) { + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + WelfordType wdB = welford_op.warp_shfl_down(val, offset); + val = welford_op.combine(val, wdB); } } -} -template -__inline__ __device__ U rsqrt(U v) { - return U(1) / sqrt(v); -} -template <> -__inline__ __device__ float rsqrt(float v) { - return rsqrtf(v); + if (threadIdx.x == 0) { + T_ACC m1; // mean + T_ACC m2; // var + SimplePair res = welford_op.project(val); + m2 = res.first; + m1 = res.second; + rstd[i] = Rsqrt_(m2 + m1 * m1 + eps); + } } -template <> -__inline__ __device__ double rsqrt(double v) { - return rsqrt(v); + +// Non-vectorized Forward for RMSNorm +template +__global__ void RMSNormForwardCUDAKernel( + int64_t N, const T* X, const T_ACC* rstd, const T* scale, T* Y) { + const int64_t i = blockIdx.x; + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + const T_ACC scale_v = + scale == nullptr ? T_ACC(1) : static_cast(scale[j]); + Y[index] = static_cast((static_cast(X[index])) * + static_cast(rstd[i]) * scale_v); + } } -namespace { // NOLINT -// This is the un-specialized struct. Note that we prevent instantiation of -// this struct by putting an undefined symbol in the function body so it won't -// compile. 
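// ---------------------------------------------------------------------------
// Illustrative note (not part of this diff): RowwiseMomentsCUDAKernel above
// recovers the RMSNorm statistic from Welford moments. With var = M2 / N and
// mean = M1 from WelfordOps::project, rsqrt(var + mean * mean + eps) equals
// rsqrt(E[x^2] + eps), i.e. the reciprocal root mean square. A minimal
// host-side reference under that assumption (hypothetical helper name):
// ---------------------------------------------------------------------------
#include <cmath>
#include <cstdint>

inline float RmsNormRstdReference(const float* x, int64_t n, float eps) {
  double sum_sq = 0.0;
  for (int64_t i = 0; i < n; ++i) {
    sum_sq += static_cast<double>(x[i]) * static_cast<double>(x[i]);
  }
  const double mean_sq = sum_sq / static_cast<double>(n);  // E[x^2]
  return static_cast<float>(1.0 /
                            std::sqrt(mean_sq + static_cast<double>(eps)));
}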
-// template -// struct SharedMemory -// { -// // Ensure that we won't compile any un-specialized types -// __device__ T *getPointer() -// { -// extern __device__ void error(void); -// error(); -// return NULL; -// } -// }; -// https://github.com/NVIDIA/apex/issues/246 -template -struct SharedMemory; +// Vectorized Helper +template +__device__ T_ACC compute_stats(const T* __restrict__ X, + const int N, + T_ACC* buf) { + using vec_t = aligned_vector; + const vec_t* X_vec = reinterpret_cast(X); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const int n_vec_to_read = N / kVecSize; + T_ACC sigma2 = 0; -template <> -struct SharedMemory { - __device__ float* getPointer() { - extern __shared__ float s_float[]; - return s_float; + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; +#pragma unroll + for (int ii = 0; ii < kVecSize; ii++) { + T_ACC val = static_cast(data.val[ii]); + sigma2 += val * val; + } } -}; -} // namespace - -template -__device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta, - bool rms_only) { - // Assumptions: - // 1) blockDim.x == WARP_SIZE - // 2) Tensors are contiguous - // - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - SharedMemory shared; - U* buf = shared.getPointer(); - U mu, sigma2; - cuWelfordMuSigma2_(vals, n1, n2, i1, mu, sigma2, buf, rms_only); - const T* lvals = vals + i1 * n2; - V* ovals = output_vals + i1 * n2; - U c_invvar = rsqrt(sigma2 + epsilon); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && (beta != NULL || rms_only)) { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = - static_cast(static_cast(gamma[i]) * c_invvar * (curr - mu) + - static_cast(beta[i])); - } else { - ovals[i] = static_cast(static_cast(gamma[i]) * c_invvar * curr); - } + // Intra-warp reduction + for (int offset = (kWarpSize >> 1); offset > 0; offset >>= 1) { + sigma2 += WARP_SHFL_DOWN_(sigma2, offset); + } + + // Inter-warp reductions + if (blockDim.y > 1) { + T_ACC* meansigmabuf = buf; + // Use simpler layout: just sigma2 + for (int offset = blockDim.y >> 1; offset > 0; offset >>= 1) { + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + meansigmabuf[wrt_y] = sigma2; } - } else { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = static_cast(c_invvar * (curr - mu)); - } else { - ovals[i] = static_cast(c_invvar * curr); - } + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y < offset) { + sigma2 += meansigmabuf[threadIdx.y]; } + __syncthreads(); } if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - mean[i1] = mu; - } - invvar[i1] = c_invvar; + meansigmabuf[0] = sigma2 / static_cast(N); } __syncthreads(); + return meansigmabuf[0]; + + } else { + return WARP_SHFL_(sigma2, 0) / static_cast(N); } } -template -__global__ void cuApplyRMSNorm_(V* __restrict__ output_vals, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma) { - cuApplyLayerNorm_( - output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); -} +template +__global__ void 
vectorized_rms_norm_kernel(const int N, + T_ACC eps, + const T* __restrict__ X, + const T* scale, + T_ACC* rstd, + T* Y) { + extern __shared__ char s_data_raw[]; + T_ACC* s_data = reinterpret_cast(s_data_raw); -template -__device__ void cuLoadWriteStridedInputs_(const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - bool rms_only) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = - curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar; - } - } else { - if (!rms_only) { - warp_buf1[write_idx] = U(0); - } - warp_buf2[write_idx] = U(0); + auto i1 = blockIdx.x; + const T* block_row = X + i1 * N; + + // Compute stats + T_ACC sigma2 = compute_stats(block_row, N, s_data); + + using vec_t = aligned_vector; + const vec_t* X_vec = reinterpret_cast(block_row); + const vec_t* scale_vec = + (scale != nullptr) ? reinterpret_cast(scale) : nullptr; + vec_t* Y_vec = reinterpret_cast(Y + i1 * N); + + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const int n_vec_to_read = N / kVecSize; + + T_ACC rstd_val = Rsqrt_(sigma2 + eps); + + if (scale_vec != nullptr) { + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; + vec_t out; +#pragma unroll + for (int ii = 0; ii < kVecSize; ii++) { + out.val[ii] = + static_cast(static_cast(scale_vec[i].val[ii]) * + (rstd_val * static_cast(data.val[ii]))); } + Y_vec[i] = out; } } else { - for (int k = 0; k < blockDim.y; ++k) { - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (!rms_only) { - warp_buf1[write_idx] = U(0); + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; + vec_t out; +#pragma unroll + for (int ii = 0; ii < kVecSize; ii++) { + out.val[ii] = + static_cast(rstd_val * static_cast(data.val[ii])); } - warp_buf2[write_idx] = U(0); + Y_vec[i] = out; } } + + if (thrx == 0) { + rstd[i1] = rstd_val; + } } -template -__device__ void cuLoadAddStridedInputs_(const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - bool rms_only) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += - curr_dout * (curr_input - 
curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar; - } - } +template +void launch_vectorized_rms_norm_kernel_driver(int N, + int64_t M, + T_ACC eps, + const T* X_data, + const T* scale_data, + T* Y_data, + T_ACC* rstd_data, + cudaStream_t stream) { + const int num_threads = 128; + const dim3 threads(kWarpSize, num_threads / kWarpSize, 1); + dim3 blocks(M); + + // Shared memory for reduction: need size proportional to threads.y and T_ACC + int nshared = threads.y > 1 ? threads.y * 3 / 2 * sizeof(T_ACC) : 0; + + vectorized_rms_norm_kernel + <<>>( + N, eps, X_data, scale_data, rstd_data, Y_data); +} + +// ----------------------------------------------------------------------- +// Backward Kernels +// ----------------------------------------------------------------------- + +template +__device__ __inline__ void compute_gI(const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + const T* __restrict__ scale, + T* dX, + const int N, + T_ACC* buf) { + const auto i1 = blockIdx.x; + const T_ACC rstd_val = rstd[i1]; + T_ACC stats_x2{0}; + constexpr int unroll = 4; + auto l = unroll * threadIdx.x; + const T* X_i = X + i1 * N; + const T* dY_i = dY + i1 * N; + T* dX_i = dX + i1 * N; + + for (; l + unroll - 1 < N; l += blockDim.x * unroll) { +#pragma unroll + for (int k = 0; k < unroll; k++) { + const auto scale_val = + (scale != nullptr) ? static_cast(scale[l + k]) : T_ACC(1); + const auto c_h = static_cast(X_i[l + k]); + const auto c_loss = static_cast(dY_i[l + k]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; } } -} + for (; l < N; l++) { + const auto scale_val = + (scale != nullptr) ? static_cast(scale[l]) : T_ACC(1); + const auto c_h = static_cast(X_i[l]); + const auto c_loss = static_cast(dY_i[l]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; + } -template -__global__ void cuComputePartGradGammaBeta_(const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - U* part_grad_gamma, - U* part_grad_beta, - bool rms_only) { - const int numsegs_n1 = - (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); - const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; - const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; - const int i1_beg_plus_one = - (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; - const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1; - const int row_stride = blockDim.x + 1; - const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); - const int thr_load_row_off = - (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; - const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; - SharedMemory shared; - U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * - // blockDim.y + (blockDim.y - - // 1)*(blockDim.x/blockDim.y) elements - U* warp_buf1 = (U*)buf; // NOLINT - U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; - // compute partial sums from strided inputs - // do this to increase number of loads in flight - cuLoadWriteStridedInputs_(i1_beg, - thr_load_row_off, - thr_load_col_off, - i2_off, - row_stride, - warp_buf1, - warp_buf2, - input, - dout, - i1_end, - n2, - mean, - invvar, - rms_only); - for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; - i1_block += blockDim.y * blockDim.y) { - cuLoadAddStridedInputs_(i1_block, - thr_load_row_off, - thr_load_col_off, - i2_off, - row_stride, - warp_buf1, - warp_buf2, - input, - dout, - i1_end, - n2, - mean, - invvar, - rms_only); + stats_x2 = BlockReduceSum(stats_x2, buf); + + if (threadIdx.x == 0) { + buf[0] = stats_x2; } __syncthreads(); - // inter-warp reductions - // sum within each warp - U acc1 = U(0); - U acc2 = U(0); - for (int k = 0; k < blockDim.y; ++k) { - int row1 = threadIdx.y + k * blockDim.y; - int idx1 = row1 * row_stride + threadIdx.x; - if (!rms_only) { - acc1 += warp_buf1[idx1]; + stats_x2 = buf[0]; + + T_ACC fH = N; + T_ACC term1 = (T_ACC(1) / fH) * rstd_val; + + for (int l = threadIdx.x; l < N; l += blockDim.x) { + const auto x = static_cast(X_i[l]); + const auto dy = static_cast(dY_i[l]); + const auto scale_val = + (scale != nullptr) ? static_cast(scale[l]) : T_ACC(1); + + T_ACC f_grad_input = fH * scale_val * dy; + f_grad_input -= (x)*rstd_val * stats_x2; + f_grad_input *= term1; + dX_i[l] = static_cast(f_grad_input); + } +} + +template +__global__ void rms_norm_grad_input_kernel(const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + const T* __restrict__ scale, + T* dX, + const int N) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* buf = reinterpret_cast(&s_data1); + compute_gI(dY, X, rstd, scale, dX, N, buf); +} + +template +__global__ void rms_norm_grad_input_kernel_vectorized( + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + const T* __restrict__ scale, + T* dX, + const int N) { + alignas(sizeof(double)) extern __shared__ char shared_data[]; + T_ACC* reduce_buf = reinterpret_cast(&shared_data); + + const auto bIdx = blockIdx.x; + const T_ACC rstd_val = rstd[bIdx]; + const T* X_i = X + bIdx * N; + const T* dY_i = dY + bIdx * N; + T* dX_i = dX + bIdx * N; + + using vec_t = aligned_vector; + const vec_t* const X_i_vec_ptr = reinterpret_cast(X_i); + const vec_t* const dY_i_vec_ptr = reinterpret_cast(dY_i); + const vec_t* const scale_vec_ptr = + (scale != nullptr) ? 
reinterpret_cast(scale) : nullptr; + vec_t* const dX_i_vec = reinterpret_cast(dX_i); + + vec_t X_i_vec_reg, dY_i_vec_reg, scale_vec_reg, dX_i_vec_reg; + for (int k = 0; k < kVecSize; ++k) { + scale_vec_reg.val[k] = T(1); + } + + T_ACC stats_x2{0}; + unsigned int l = threadIdx.x * kVecSize; + for (; l + kVecSize - 1 < N; l += blockDim.x * kVecSize) { + unsigned int vec_idx = l / kVecSize; + if (scale != nullptr) { + scale_vec_reg = scale_vec_ptr[vec_idx]; + } + + X_i_vec_reg = X_i_vec_ptr[vec_idx]; + dY_i_vec_reg = dY_i_vec_ptr[vec_idx]; + + for (int k = 0; k < kVecSize; ++k) { + const auto scale_val = static_cast(scale_vec_reg.val[k]); + const auto c_h = static_cast(X_i_vec_reg.val[k]); + const auto c_loss = static_cast(dY_i_vec_reg.val[k]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; } - acc2 += warp_buf2[idx1]; } - if (!rms_only) { - warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + // Tail Loop + for (; l < N; l++) { + const auto scale_val = + (scale != nullptr) ? static_cast(scale[l]) : T_ACC(1); + const auto c_h = static_cast(X_i[l]); + const auto c_loss = static_cast(dY_i[l]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; + } + + stats_x2 = BlockReduceSum(stats_x2, reduce_buf); + if (threadIdx.x == 0) { + reduce_buf[0] = stats_x2; } - warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; __syncthreads(); - // sum all warps - for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { - if (threadIdx.y < offset) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + offset; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - if (!rms_only) { - warp_buf1[idx1] += warp_buf1[idx2]; - } - warp_buf2[idx1] += warp_buf2[idx2]; + stats_x2 = reduce_buf[0]; + + T_ACC fH = N; + T_ACC term1 = (T_ACC(1) / fH) * rstd_val; + + l = threadIdx.x * kVecSize; + for (; l + kVecSize - 1 < N; l += blockDim.x * kVecSize) { + unsigned int vec_idx = l / kVecSize; + if (scale != nullptr) { + scale_vec_reg = scale_vec_ptr[vec_idx]; } - __syncthreads(); - } - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (threadIdx.y == 0 && i2 < n2) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + 1; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - if (!rms_only) { - part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + + X_i_vec_reg = X_i_vec_ptr[vec_idx]; + dY_i_vec_reg = dY_i_vec_ptr[vec_idx]; + + for (int k = 0; k < kVecSize; ++k) { + const auto scale_val = static_cast(scale_vec_reg.val[k]); + const auto x = static_cast(X_i_vec_reg.val[k]); + const auto dy = static_cast(dY_i_vec_reg.val[k]); + + T_ACC f_grad_input = fH * scale_val * dy; + f_grad_input -= (x)*rstd_val * stats_x2; + f_grad_input *= term1; + dX_i_vec_reg.val[k] = static_cast(f_grad_input); } - part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + + dX_i_vec[vec_idx] = dX_i_vec_reg; + } + + // Tail Loop + for (; l < N; l += blockDim.x) { + const auto x = static_cast(X_i[l]); + const auto dy = static_cast(dY_i[l]); + const auto scale_val = + (scale != nullptr) ? 
static_cast(scale[l]) : T_ACC(1); + + T_ACC f_grad_input = fH * scale_val * dy; + f_grad_input -= (x)*rstd_val * stats_x2; + f_grad_input *= term1; + dX_i[l] = static_cast(f_grad_input); } } -template -__global__ void cuComputeGradGammaBeta_(const U* part_grad_gamma, - const U* part_grad_beta, - const int part_size, - const int n1, - const int n2, - V* grad_gamma, - V* grad_beta, - bool rms_only) { - // sum partial gradients for gamma and beta - SharedMemory shared; - U* buf = shared.getPointer(); - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (i2 < n2) { - // each warp does sequential reductions until reduced part_size is num_warps - int num_warp_reductions = part_size / blockDim.y; - U sum_gamma = U(0); - U sum_beta = U(0); - const U* part_grad_gamma_ptr = - part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; - const U* part_grad_beta_ptr = - part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; - for (int warp_offset = 0; warp_offset < num_warp_reductions; - ++warp_offset) { - sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; - if (!rms_only) { - sum_beta += part_grad_beta_ptr[warp_offset * n2]; - } +template +__device__ __forceinline__ void blockReduceScaleBackwardHelper( + int64_t M_start, + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + T* __restrict__ dscale, + T_ACC* dscale_sum) { + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + int64_t thread_x = static_cast(blockIdx.x) * block_dim_x + + static_cast(threadIdx.x); + + int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); + int64_t mean_index = + M_start + static_cast(threadIdx.y) * rows_per_thread_y; + T_ACC warp_rstd = 0; + if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { + warp_rstd = rstd[mean_index + lane_id]; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + __syncwarp(); +#endif + + T_ACC dY_regs[rows_per_thread_y] = {0}; + T_ACC X_regs[rows_per_thread_y] = {0}; +#pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + int64_t current_y = + M_start + static_cast(threadIdx.y) * rows_per_thread_y + i; + bool active = true; + if (check_x && thread_x >= N) { + active = false; } - // inter-warp reductions - const int nbsize3 = blockDim.x * blockDim.y / 2; - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - // top half write to shared memory - if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[write_idx] = sum_gamma; - if (!rms_only) { - buf[write_idx + nbsize3] = sum_beta; - } - } - __syncthreads(); - // bottom half sums - if (threadIdx.y < offset) { - const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; - sum_gamma += buf[read_idx]; - if (!rms_only) { - sum_beta += buf[read_idx + nbsize3]; - } - } - __syncthreads(); + if (check_y && current_y >= M) { + active = false; } - // write out fully summed gradients - if (threadIdx.y == 0) { - grad_gamma[i2] = sum_gamma; - if (!rms_only) { - grad_beta[i2] = sum_beta; - } + if (active) { + dY_regs[i] = static_cast(dY[current_y * N + thread_x]); + X_regs[i] = static_cast(X[current_y * N + thread_x]); } } + +#pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + T_ACC rstd_reg = WARP_SHFL_(warp_rstd, i, kWarpSize); + *dscale_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; + } } -template -__global__ void cuComputeGradInput_(const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - 
const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - const V* gamma, - T* grad_input, - bool rms_only) { - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - U sum_loss1 = U(0); - U sum_loss2 = U(0); - U c_mean; - if (!rms_only) { - c_mean = mean[i1]; +template +__device__ __forceinline__ void blockReduceScaleBackwardWithChecks( + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + T* __restrict__ dscale, + T_ACC* dscale_sum) { + for (int64_t M_start = static_cast(blockIdx.y) * rows_per_block_y; + M_start < M; + M_start += rows_per_block_y * gridDim.y) { + int64_t M_end = M_start + rows_per_block_y - 1; + if (!check_y || M_end < M) { + blockReduceScaleBackwardHelper( + M_start, M, N, dY, X, rstd, dscale, dscale_sum); + } else { + blockReduceScaleBackwardHelper( + M_start, M, N, dY, X, rstd, dscale, dscale_sum); } - const U c_invvar = invvar[i1]; - const T* k_input = input + i1 * n2; - const V* k_dout = dout + i1 * n2; - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - const U gamma_tmp = static_cast(gamma[l + k]); - if (!rms_only) { - sum_loss1 += c_loss * gamma_tmp; - sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - const U gamma_tmp = static_cast(gamma[l]); - if (!rms_only) { - sum_loss1 += c_loss * gamma_tmp; - sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; - } - } + } +} + +template +__global__ void ScaleBackwardCUDAKernelTemplate(int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + T* __restrict__ dscale) { + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + static_assert(rows_per_thread_y <= kWarpSize); + + T_ACC dscale_sum = 0; + + // Template : Boundary check of x and y + if (aligned_grid) { + blockReduceScaleBackwardWithChecks( + M, N, dY, X, rstd, dscale, &dscale_sum); + } else { + if (static_cast(blockIdx.x) * block_dim_x + block_dim_x - 1 < N) { + blockReduceScaleBackwardWithChecks( + M, N, dY, X, rstd, dscale, &dscale_sum); } else { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h)*c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h)*c_invvar; - } - } + blockReduceScaleBackwardWithChecks( + M, N, dY, X, rstd, dscale, &dscale_sum); } - // intra-warp reductions - for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { - if (!rms_only) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } + + int64_t thread_x = + (static_cast(blockIdx.x)) * block_dim_x + threadIdx.x; + + if (partial_reduction || (blockDim.y == 1 && 
gridDim.y == 1)) { + if (aligned_grid || thread_x < N) { + int64_t thread_y = + (static_cast(blockIdx.y)) * blockDim.y + threadIdx.y; + if (dscale) { + dscale[thread_y * N + thread_x] = static_cast(dscale_sum); } - sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } - // inter-warp reductions - if (blockDim.y > 1) { - SharedMemory shared; - U* buf = shared.getPointer(); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - if (!rms_only) { - buf[2 * wrt_i] = sum_loss1; - } - buf[2 * wrt_i + 1] = sum_loss2; - } - __syncthreads(); - // lower half merges - if (threadIdx.y < offset) { - const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - if (!rms_only) { - sum_loss1 += buf[2 * read_i]; - } - sum_loss2 += buf[2 * read_i + 1]; - } - __syncthreads(); - } - if (threadIdx.y == 0) { - if (!rms_only) { - buf[2 * threadIdx.x] = sum_loss1; - } - buf[2 * threadIdx.x + 1] = sum_loss2; - } - __syncthreads(); - if (threadIdx.y != 0) { - if (!rms_only) { - sum_loss1 = buf[2 * threadIdx.x]; - } - sum_loss2 = buf[2 * threadIdx.x + 1]; + } else { + // Full reduction using shared memory + static_assert(rows_per_thread_y <= kWarpSize); + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dscale; + int padded_bx = (block_dim_x + 1); + s_dscale = s_data_typed; + s_dscale[threadIdx.y * padded_bx + threadIdx.x] = dscale_sum; + __syncthreads(); + + static_assert(block_dim_x * block_dim_y % kWarpSize == 0); + constexpr int warps_available_to_reduce = + block_dim_x * block_dim_y / kWarpSize; + int thread_id = threadIdx.y * block_dim_x + threadIdx.x; + int warp_id = thread_id / kWarpSize; + int lane_id = thread_id & (kWarpSize - 1); +#pragma unroll + for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) { + T_ACC reg_dscale; + if (lane_id < block_dim_y) { + reg_dscale = s_dscale[lane_id * padded_bx + i]; } - } - // all threads now have the two sums over l - U fH = (U)n2; - U term1 = (U(1) / fH) * c_invvar; - T* k_grad_input = grad_input + i1 * n2; - if (gamma != NULL) { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * static_cast(gamma[l]); - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h)*c_invvar * sum_loss2; - } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); +#pragma unroll + for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { + reg_dscale += WARP_SHFL_XOR_(reg_dscale, delta, kWarpSize); } - } else { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss; - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h)*c_invvar * sum_loss2; + + int64_t out_index = static_cast(blockIdx.x) * block_dim_x + i; + if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { + if (dscale) { + dscale[out_index] = static_cast(reg_dscale); } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); } } - // prevent race where buf is written again before reads are done - __syncthreads(); } } -static cudaDeviceProp GetDevicePropImpl() { - 
int device = -1; - PD_CHECK(cudaGetDevice(&device) == cudaSuccess); - cudaDeviceProp prop; - PD_CHECK(cudaGetDeviceProperties(&prop, device) == cudaSuccess); - return prop; -} +template +void ConfigureAndLaunchScaleBackwardKernel(const T* dY_data, + const T* X_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + T* dscale_data, + cudaStream_t cuda_stream) { + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + blocks.y = 1; + size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; -static cudaDeviceProp* GetDeviceProp() { - static auto prop = GetDevicePropImpl(); - return ∝ + if (blocks.y == 1 && threads.y == 1) { + if (aligned_grid) { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } else { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } + } else { + if (aligned_grid) { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } else { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } + } } -template -void HostApplyRMSNorm_(V* output, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, - cudaStream_t stream) { - // auto stream = at::cuda::getCurrentCUDAStream().stream(); - const dim3 threads(32, 4, 1); - // const uint64_t maxGridY = - // at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; - cuApplyRMSNorm_<<>>( - output, invvar, input, n1, n2, U(epsilon), gamma); -} +// ----------------------------------------------------------------------- +// Host API Implementations +// ----------------------------------------------------------------------- template -void cuda_rms_norm(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& scale, - int rows, - int cols, - float epsilon, - DenseTensor* y, - DenseTensor* invvar) { -#define DISPATCH_FWD_CASE(scalar_t_out) \ - HostApplyRMSNorm_( \ - y->data(), \ - invvar->data(), \ - const_cast(x.data()), \ - rows, \ - cols, \ - epsilon, \ - const_cast(scale.data()), \ - dev_ctx.stream()) - // scale.dtype() same as y->dtype() - if (scale.dtype() == phi::DataType::FLOAT32) { - DISPATCH_FWD_CASE(float); - } else if (scale.dtype() == phi::DataType::BFLOAT16) { - DISPATCH_FWD_CASE(phi::bfloat16); - } -#undef DISPATCH_FWD_CASE -} +void RMSNormFwdKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale_opt, + const std::vector& normalized_shape, + double epsilon, + DenseTensor* y, + DenseTensor* invvar) { + using T_ACC = typename phi::dtype::MPTypeTrait::Type; + + int begin_norm_axis = x.dims().size() - normalized_shape.size(); + + auto matrix_dim = common::flatten_to_2d(x.dims(), begin_norm_axis); + int64_t rows = matrix_dim[0]; + int64_t cols = matrix_dim[1]; + + auto* scale_ptr = scale_opt.get_ptr(); + const DenseTensor& scale = *scale_ptr; -template -void HostRMSNormGradient_(const Context& dev_ctx, - const V* dout, - const U* invvar, - const DenseTensor& input, - int n1, - int n2, - const V* gamma, - double epsilon, - T* grad_input, - V* grad_gamma, - cudaStream_t stream) { - if (gamma != NULL) { - const int part_size = 16; - const dim3 
threads2(32, 4, 1); - const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); - const int nshared2_a = - 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - auto place = input.place(); - DenseTensor part_grad_gamma = - Empty(dev_ctx, {part_size, n2}); - cuComputePartGradGammaBeta_<<>>( - dout, - input.data(), - n1, - n2, - invvar, // unused - invvar, - U(epsilon), - part_grad_gamma.data(), - part_grad_gamma.data(), /* unused */ - true); - - const dim3 threads3(32, 8, 1); - const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta_<<>>( - part_grad_gamma.data(), - part_grad_gamma.data(), /* unused */ - part_size, - n1, - n2, - grad_gamma, - grad_gamma, /* unused */ - true); + auto* x_data = x.data(); + auto* scale_data = scale_ptr ? scale.data() : nullptr; + auto* y_data = dev_ctx.template Alloc(y); + auto* rstd_data = dev_ctx.template Alloc(invvar); + + auto stream = dev_ctx.stream(); + + // When using a vectorization size of 8 in fp16 and bf16, there may be + // misalignment of accuracy and torch alignment. + if (!FLAGS_use_accuracy_compatible_kernel && rows <= 1024 && + (cols / rows >= 32)) { + constexpr int num_vec_elems2 = 8; + constexpr int alignment2 = num_vec_elems2 * sizeof(T); + bool can_vec_X2 = can_vectorize(x_data, alignment2); + bool can_vec_Y2 = can_vectorize(y_data, alignment2); + bool can_vec_scale2 = can_vectorize(scale_data, alignment2); + bool is_supported_type2 = (std::is_same::value || + std::is_same::value); + if (is_supported_type2 && + cols <= + static_cast(1ULL << std::numeric_limits::digits) && + cols % num_vec_elems2 == 0 && can_vec_X2 && can_vec_Y2 && + can_vec_scale2) { + launch_vectorized_rms_norm_kernel_driver( + cols, + rows, + static_cast(epsilon), + x_data, + scale_data, + y_data, + rstd_data, + stream); + return; + } } - // compute grad_input - const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - const dim3 threads1(32, 4, 1); - int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput_<<>>( - dout, - input.data(), - n1, - n2, - invvar, /* unused */ - invvar, - U(epsilon), - gamma, - grad_input, - true); + // Check vectorization conditions + constexpr int num_vec_elems = 4; + constexpr int alignment = num_vec_elems * sizeof(T); + bool can_vec_X = can_vectorize(x_data, alignment); + bool can_vec_Y = can_vectorize(y_data, alignment); + bool can_vec_scale = can_vectorize(scale_data, alignment); + bool is_supported_type = (std::is_same::value || + std::is_same::value || + std::is_same::value); + + if (is_supported_type && + cols <= + static_cast(1ULL << std::numeric_limits::digits) && + cols % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_scale) { + launch_vectorized_rms_norm_kernel_driver( + cols, + rows, + static_cast(epsilon), + x_data, + scale_data, + y_data, + rstd_data, + stream); + + } else { + RowwiseMomentsCUDAKernel + <<>>( + cols, static_cast(epsilon), x_data, rstd_data); + + RMSNormForwardCUDAKernel<<>>( + cols, x_data, rstd_data, scale_data, y_data); + } } template -void cuda_rms_norm_gradient(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& invvar, - const DenseTensor& dy, - int rows, - int cols, - float epsilon, - DenseTensor* grad_x, - DenseTensor* grad_scale) { -#define DISPATCH_BWD_CASE(scalar_t_out) \ - HostRMSNormGradient_( \ - dev_ctx, \ - dy.data(), \ - invvar.data(), \ - x, \ - rows, \ - cols, \ - scale.data(), \ - epsilon, \ - grad_x->data(), \ - grad_scale->data(), \ - dev_ctx.stream()) - if (scale.dtype() == phi::DataType::FLOAT32) { - DISPATCH_BWD_CASE(float); - } else if (scale.dtype() == phi::DataType::BFLOAT16) { - DISPATCH_BWD_CASE(phi::bfloat16); +void RMSNormBwdKernel(const Context& dev_ctx, + const DenseTensor& X, + const paddle::optional& scale_opt, + const DenseTensor& invvar, + const DenseTensor& dY, + const std::vector& normalized_shape, + double epsilon, + DenseTensor* dX, + DenseTensor* dscale) { + using T_ACC = typename phi::dtype::MPTypeTrait::Type; + + int begin_norm_axis = X.dims().size() - normalized_shape.size(); + + // X, dY: [Batch, ..., Feature] -> flatten to [M, N] + // scale, dscale: [Feature] -> [N] + // invvar: [Batch, ...] -> [M] + + auto matrix_dim = common::flatten_to_2d(X.dims(), begin_norm_axis); + int64_t M = matrix_dim[0]; + int64_t N = matrix_dim[1]; + + auto* scale_ptr = scale_opt.get_ptr(); + const DenseTensor& scale = *scale_ptr; + + auto* dY_data = dY.data(); + auto* X_data = X.data(); + auto* scale_data = scale_ptr ? scale.data() : nullptr; + auto* invvar_data = invvar.data(); + + auto* dX_data = dX ? dev_ctx.template Alloc(dX) : nullptr; + auto* dscale_data = dscale ? dev_ctx.template Alloc(dscale) : nullptr; + + auto stream = dev_ctx.stream(); + + // 1. 
Compute dX
+  if (dX_data) {
+    static constexpr int kVecSize = 4;
+    bool bVectorSizeMultiple = (N % kVecSize == 0);
+    const unsigned int alignment = sizeof(T) * kVecSize;
+    bool bAlignedBuffers = can_vectorize(dY_data, alignment) &&
+                           can_vectorize(X_data, alignment) &&
+                           can_vectorize(scale_data, alignment) &&
+                           can_vectorize(dX_data, alignment);
+    bool is_supported_type = (std::is_same<T, float>::value ||
+                              std::is_same<T, phi::dtype::float16>::value ||
+                              std::is_same<T, phi::dtype::bfloat16>::value);
+
+    const unsigned int alignment2 = sizeof(T) * 8;
+    bool bAlignedBuffers2 = can_vectorize(dY_data, alignment2) &&
+                            can_vectorize(X_data, alignment2) &&
+                            can_vectorize(scale_data, alignment2) &&
+                            can_vectorize(dX_data, alignment2);
+    bool is_supported_type2 = (std::is_same<T, phi::dtype::float16>::value ||
+                               std::is_same<T, phi::dtype::bfloat16>::value);
+
+    dim3 blocks(M);
+    constexpr int num_threads = 128;
+    constexpr int nshared = (num_threads / kWarpSize) * sizeof(T_ACC);
+
+    // With a vectorization size of 8 for fp16/bf16, results may not be
+    // bit-wise aligned with the accuracy-compatible (torch-aligned) path.
+    if (!FLAGS_use_accuracy_compatible_kernel && is_supported_type2 &&
+        bAlignedBuffers2 && (N % 8 == 0 && M <= 1024 && (N / M >= 32))) {
+      rms_norm_grad_input_kernel_vectorized
+          <<>>(
+              dY_data, X_data, invvar_data, scale_data, dX_data, N);
+    } else if (is_supported_type && bAlignedBuffers && bVectorSizeMultiple) {
+      rms_norm_grad_input_kernel_vectorized
+          <<>>(
+              dY_data, X_data, invvar_data, scale_data, dX_data, N);
+    } else {
+      rms_norm_grad_input_kernel
+          <<>>(
+              dY_data, X_data, invvar_data, scale_data, dX_data, N);
+    }
+  }
+
+  // 2. Compute dscale
+  if (scale_data) {
+    constexpr int block_dim_x = 32;
+    const int sm_count = dev_ctx.GetSMCount();
+    if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) {
+      // When M >> N and N is very small, we can parallelize and accelerate
+      // the computation by launching multiple blocks along the M-dimension (y).
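+      // Each block reduces its own slice of rows into a partial dscale; the
+      // per-block partials are staged in dscale_blocks with shape
+      // [blocks.y * threads.y, N] and then summed over axis 0 with
+      // phi::SumKernel below.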
+      constexpr int block_dim_y = 1;
+      constexpr int rows_per_block_y = 32;
+      bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
+      dim3 threads{block_dim_x, block_dim_y};
+      dim3 blocks;
+      blocks.x = (N + block_dim_x - 1) / block_dim_x;
+      blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y;
+      constexpr int max_grid_size = 64 * 1024 / 2;
+      blocks.y = std::min(max_grid_size / blocks.x, blocks.y);
+
+      DenseTensor dscale_blocks;
+      dscale_blocks.Resize({static_cast<int64_t>(blocks.y * threads.y), N});
+      T* dscale_blocks_ptr = dev_ctx.template Alloc<T>(&dscale_blocks);
+
+      if (aligned_grid) {
+        ScaleBackwardCUDAKernelTemplate<<>>(
+            M, N, dY_data, X_data, invvar_data, dscale_blocks_ptr);
+      } else {
+        ScaleBackwardCUDAKernelTemplate<<>>(
+            M, N, dY_data, X_data, invvar_data, dscale_blocks_ptr);
+      }
+
+      // Sum-reduce along the blocks.y dimension to get the final dscale.
+      phi::SumKernel<T, Context>(
+          dev_ctx, dscale_blocks, {0}, dscale->dtype(), false, dscale);
+
+    } else {
+      if (M < 64) {
+        ConfigureAndLaunchScaleBackwardKernel(
+            dY_data, X_data, invvar_data, M, N, dscale_data, stream);
+      } else if (M < 128) {
+        ConfigureAndLaunchScaleBackwardKernel(
+            dY_data, X_data, invvar_data, M, N, dscale_data, stream);
+      } else if (M < 256) {
+        ConfigureAndLaunchScaleBackwardKernel(
+            dY_data, X_data, invvar_data, M, N, dscale_data, stream);
+      } else {
+        ConfigureAndLaunchScaleBackwardKernel(
+            dY_data, X_data, invvar_data, M, N, dscale_data, stream);
+      }
+    }
+  }
   }
-#undef DISPATCH_BWD_CASE
 }
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/rms_norm_xpu_kernel.cc b/paddle/phi/kernels/xpu/rms_norm_xpu_kernel.cc
index a2a5c102357bfa..dbe1df34bebb89 100644
--- a/paddle/phi/kernels/xpu/rms_norm_xpu_kernel.cc
+++ b/paddle/phi/kernels/xpu/rms_norm_xpu_kernel.cc
@@ -38,10 +38,27 @@ static void GetRowsCols(const std::vector<int64_t> &shape,
 template <typename T, typename Context>
 void RMSNormFwdKernel(const Context &dev_ctx,
                       const DenseTensor &x,
-                      const DenseTensor &scale,
-                      float epsilon,
+                      const paddle::optional<DenseTensor> &scale_opt,
+                      const std::vector<int64_t> &normalized_shape,
+                      double epsilon,
                       DenseTensor *y,
                       DenseTensor *invvar) {
+  int begin_norm_axis = x.dims().size() - normalized_shape.size();
+  PADDLE_ENFORCE_EQ(
+      begin_norm_axis,
+      x.dims().size() - 1,
+      common::errors::InvalidArgument(
+          "XPU RMSNorm only supports begin_norm_axis=%d, but got %d",
+          x.dims().size() - 1,
+          begin_norm_axis));
+
+  auto *scale_ptr = scale_opt.get_ptr();
+  if (scale_ptr == nullptr) {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Scale must be provided for RMSNorm forward"));
+  }
+  const DenseTensor &scale = *scale_ptr;
+
   int64_t rows, cols;
   GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
@@ -123,12 +140,29 @@ void RMSNormFwdKernel(const Context &dev_ctx,
 template <typename T, typename Context>
 void RMSNormBwdKernel(const Context &dev_ctx,
                       const DenseTensor &x,
-                      const DenseTensor &scale,
+                      const paddle::optional<DenseTensor> &scale_opt,
                       const DenseTensor &invvar,
                       const DenseTensor &y_grad,
-                      float epsilon,
+                      const std::vector<int64_t> &normalized_shape,
+                      double epsilon,
                       DenseTensor *x_grad,
                       DenseTensor *scale_grad) {
+  int begin_norm_axis = x.dims().size() - normalized_shape.size();
+  PADDLE_ENFORCE_EQ(
+      begin_norm_axis,
+      x.dims().size() - 1,
+      common::errors::InvalidArgument(
+          "XPU RMSNorm only supports begin_norm_axis=%d, but got %d",
+          x.dims().size() - 1,
+          begin_norm_axis));
+
+  auto *scale_ptr = scale_opt.get_ptr();
+  if (scale_ptr == nullptr) {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Scale must be provided for RMSNorm backward"));
+  }
+  const DenseTensor &scale = *scale_ptr;
+
   int64_t rows, cols;
GetRowsCols(common::vectorize(x.dims()), &rows, &cols); dev_ctx.template Alloc(x_grad); diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 962bf9b655cf16..0cb4d705ec6e00 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -4239,14 +4239,15 @@ data_type: w - backward_op: rms_norm_grad - forward: rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) - args: (Tensor x, Tensor scale,Tensor invvar, Tensor y_grad, float epsilon) + forward: rms_norm (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon = 1e-5) -> Tensor(y), Tensor(invvar) + args: (Tensor x, Tensor scale, Tensor invvar, Tensor y_grad, int64_t[] normalized_shape={}, double epsilon = 1e-5) output: Tensor(x_grad), Tensor(scale_grad) infer_meta: func: RMSNormGradInferMeta kernel: func: rms_norm_grad data_type: x + optional : scale - backward_op: shuffle_batch_grad forward: shuffle_batch (Tensor x, Tensor seed, int startup_seed=0) -> Tensor(out), Tensor(shuffle_idx), Tensor(seed_out) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 0200bc487a55b8..6184e9b50441fa 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -6109,12 +6109,13 @@ traits : paddle::dialect::ForwardOnlyTrait - op: rms_norm - args: (Tensor x, Tensor scale, float epsilon) + args: (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon= 1e-5) output: Tensor(y), Tensor(invvar) infer_meta: func: RmsNormInferMeta kernel: func: rms_norm data_type: x + optional : scale backward: rms_norm_grad interfaces : paddle::dialect::InferSymbolicShapeInterface diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index fa59c9aad14d67..4e0f88c25707b2 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -463,7 +463,7 @@ def layer_norm( def rms_norm( input: Tensor, - normalized_shape: int | Sequence[int], + normalized_shape: Sequence[int], weight: Tensor | None = None, eps: float = 1e-5, name: str | None = None, @@ -473,7 +473,7 @@ def rms_norm( Args: input (Tensor): Input tensor of shape [rows, cols] or higher dimensions (flattened to 2D). - normalized_shape(int|list|tuple): Input shape from an expected input of + normalized_shape(list|tuple): Input shape from an expected input of size :math:`[*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]]`. If it is a single integer, this module will normalize over the last dimension which is expected to be of that specific size. @@ -485,46 +485,9 @@ def rms_norm( out (Tensor): Normalized tensor of same shape as input. invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row. """ - input_shape = list(input.shape) - input_ndim = len(input_shape) - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = [normalized_shape] - elif isinstance(normalized_shape, tuple): - normalized_shape = list(normalized_shape) - elif not isinstance(normalized_shape, list): - raise ValueError( - "`normalized_shape` should be int, list of ints or tuple of ints." 
- ) - - normalized_ndim = len(normalized_shape) - begin_norm_axis = input_ndim - normalized_ndim - if input_ndim < normalized_ndim or ( - not paddle.utils.is_same_shape( - input_shape[begin_norm_axis:], normalized_shape - ) - ): - str_normalized_shape = str(normalized_shape) - raise ValueError( - 'Given normalized_shape is ' - + str_normalized_shape - + ', expected input with shape [*, ' - + str_normalized_shape[1:] - + ', but got input shape ' - + str(input_shape) - ) - - if normalized_ndim != 1: - raise ValueError( - 'Given len(normalized_shape) is ' - + normalized_ndim - + ', expected len(normalized_shape) is 1.' - ) - - if weight is None: - raise ValueError("weight must not be None.") if in_dynamic_or_pir_mode(): - return _C_ops.rms_norm(input, weight, eps) + return _C_ops.rms_norm(input, weight, normalized_shape, eps) helper = LayerHelper('rms_norm', **locals()) from paddle.base.data_feeder import convert_dtype @@ -539,7 +502,7 @@ def rms_norm( type='rms_norm', inputs=inputs, outputs={'out': out, 'invvar': invvar}, - attrs={'eps': eps}, + attrs={"normalized_shape": normalized_shape, "eps": eps}, ) return out, invvar diff --git a/test/legacy_test/test_rms_norm_op.py b/test/legacy_test/test_rms_norm_op.py index 3bb99510800794..43c5bbad049632 100644 --- a/test/legacy_test/test_rms_norm_op.py +++ b/test/legacy_test/test_rms_norm_op.py @@ -49,9 +49,13 @@ def setUp(self): np.random.seed(2023) x = np.random.randn(*self.x_shape).astype(self.dtype) scale = np.random.randn(self.x_shape[-1]).astype(self.dtype) + normalized_shape = [self.x_shape[-1]] self.inputs = {'x': x, 'scale': scale} - self.attrs = {'epsilon': self.epsilon} + self.attrs = { + 'normalized_shape': normalized_shape, + 'epsilon': self.epsilon, + } y_ref, invvar_ref = rms_norm_reference(x, scale, epsilon=self.epsilon) self.outputs = {'y': y_ref, 'invvar': invvar_ref} @@ -152,5 +156,23 @@ def test_api_dygraph(self): ) +class TestRMSNormValueError(unittest.TestCase): + def test_normalized_shape_type_error(self): + x = paddle.randn([2, 3]) + with self.assertRaises(TypeError): + rms_norm(x, "invalid_shape") + + def test_input_shape_mismatch(self): + x = paddle.randn([2, 3]) + with self.assertRaises(ValueError): + rms_norm(x, [4]) + + def test_weight_shape_mismatch(self): + x = paddle.randn([2, 3]) + weight = paddle.randn([4]) + with self.assertRaises(ValueError): + rms_norm(x, [3], weight=weight) + + if __name__ == '__main__': unittest.main()
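
The math that the GPU, XPU, and Python paths above all implement reduces to a per-row inverse RMS over the flattened [M, N] view, where N is the product of normalized_shape and begin_norm_axis = x.ndim - len(normalized_shape). The sketch below is a minimal NumPy illustration of that contract only; it is not the rms_norm_reference helper used by the tests (whose definition is outside this diff), and the helper names rms_norm_ref / rms_norm_dscale_ref as well as the float64 accumulation are illustrative assumptions.

    import numpy as np

    def rms_norm_ref(x, normalized_shape, weight=None, eps=1e-5):
        # Flatten to [M, N]; N is the product of the trailing normalized dims.
        n = int(np.prod(normalized_shape))
        x2d = x.reshape(-1, n).astype(np.float64)
        # invvar is the per-row inverse RMS: 1 / sqrt(mean(x^2) + eps).
        invvar = 1.0 / np.sqrt(np.mean(x2d * x2d, axis=1) + eps)
        y = x2d * invvar[:, None]
        if weight is not None:
            y = y * weight.reshape(1, n).astype(np.float64)
        return y.reshape(x.shape).astype(x.dtype), invvar

    def rms_norm_dscale_ref(dy, x, invvar, normalized_shape):
        # dscale accumulates dy * x * invvar over the M (row) dimension,
        # mirroring what the scale-backward kernel paths above compute.
        n = int(np.prod(normalized_shape))
        dy2d = dy.reshape(-1, n).astype(np.float64)
        x2d = x.reshape(-1, n).astype(np.float64)
        return (dy2d * x2d * invvar[:, None]).sum(axis=0).reshape(normalized_shape)

For example, rms_norm_ref(np.random.randn(2, 3, 8).astype(np.float32), [8], np.ones(8, np.float32)) normalizes over the last dimension only, which corresponds to begin_norm_axis = 2 in the kernel code.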