Commit 09f1b3d

Fix int32 overflow issues for large tensor support in paddle/phi/kernels/impl (#76107) (#76276)
* Fix int32 overflow in svd_grad and conv kernel impl
* fix

Co-authored-by: Zhan Rongrui <[email protected]>
1 parent 1fa7b0a commit 09f1b3d

27 files changed: +192 -107 lines
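
The same fix pattern repeats across the files below: element counts and CUDA thread indices that used to be 32-bit (int / unsigned int) are widened to int64_t, because a tensor holding more than 2^31 - 1 elements overflows 32-bit index arithmetic. A minimal standalone sketch of the before/after grid-stride loop, using illustrative kernel names rather than Paddle code:

// Illustrative only, not from the Paddle source.
#include <cstdint>

// Before: num and the index wrap once the tensor exceeds 2^31 - 1 elements,
// and blockIdx.x * blockDim.x is evaluated in 32-bit unsigned arithmetic.
__global__ void ScaleByTwoInt32(const float* in, float* out, int num) {
  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
    out[i] = in[i] * 2.0f;
  }
}

// After: the count, the starting index, and the loop variable are all 64-bit,
// and the block/thread product is widened before the multiplication happens.
__global__ void ScaleByTwoInt64(const float* in, float* out, int64_t num) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    out[i] = in[i] * 2.0f;
  }
}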

paddle/phi/kernels/impl/accuracy_check_kernel_impl.h

Lines changed: 11 additions & 11 deletions
@@ -143,12 +143,12 @@ __global__ void AccuracyCheckCUDAKernel(const T* in_data,
     const double rtol,
     const double atol,
     bool equal_nan,
-    int num,
+    int64_t num,
     bool* out_data) {
-  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   bool val;
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
-  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
     const double a = static_cast<MPType>(in_data[i]);
     const double b = static_cast<MPType>(other_data[i]);
     if (isnan(a) || isnan(b)) {
@@ -173,11 +173,11 @@ __global__ void AccuracyCheckCUDAKernel<phi::complex64>(
     const double rtol,
     const double atol,
     bool equal_nan,
-    int num,
+    int64_t num,
     bool* out_data) {
-  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   bool val;
-  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
     const phi::complex64 a = in_data[i];
     const phi::complex64 b = other_data[i];
     if (isnan(a) || isnan(b)) {
@@ -203,11 +203,11 @@ __global__ void AccuracyCheckCUDAKernel<phi::complex128>(
     const double rtol,
     const double atol,
     bool equal_nan,
-    int num,
+    int64_t num,
     bool* out_data) {
-  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   bool val;
-  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
     const phi::complex128 a = in_data[i];
     const phi::complex128 b = other_data[i];
     if (isnan(a) || isnan(b)) {
@@ -236,12 +236,12 @@ struct AccuracyCheckFunctor<phi::GPUContext, T> {
                   const double atol,
                   bool equal_nan,
                   DenseTensor* output) {
-    int num = in.numel();
+    int64_t num = in.numel();
     const T* in_data = in.data<T>();
     const T* other_data = other.data<T>();
     bool* out_data = dev_ctx.template Alloc<bool>(output);
     int block = 1024;
-    int grid = (block - 1 + num) / block;
+    int64_t grid = (block - 1 + num) / block;
     grid = (grid > block) ? block : grid;
 #ifdef PADDLE_WITH_HIP
     hipMemset(out_data, true, num * sizeof(bool));
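
The functor change above also computes the launch grid in 64-bit before clamping it to 1024 blocks; the grid-stride loop in the kernel then covers whatever elements the clamp leaves over. A small sketch of that launch math, with a made-up helper name:

#include <algorithm>
#include <cstdint>

// Illustrative helper mirroring the pattern above: do the block-count math in
// 64-bit so neither the sum nor the stored quotient can wrap, then cap the
// grid and let the kernel's grid-stride loop handle the remaining elements.
inline int64_t CappedGridSize(int64_t num, int block = 1024) {
  int64_t grid = (block - 1 + num) / block;
  return std::min<int64_t>(grid, block);
}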

paddle/phi/kernels/impl/conv_grad_kernel_impl.h

Lines changed: 8 additions & 8 deletions
@@ -85,7 +85,7 @@ void ConvGradKernel(const Context& dev_ctx,
   UpdatePaddingAndDilation<int>(
       &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
 
-  const int batch_size = static_cast<int>(transformed_input.dims()[0]);
+  const int64_t batch_size = transformed_input.dims()[0];
 
   // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
   std::vector<int64_t> filter_shape_vec(common::vectorize(filter.dims()));
@@ -125,8 +125,8 @@ void ConvGradKernel(const Context& dev_ctx,
 
   // convolution backward input operator: gemm + col2im(or col2vol)
   // convolution backward weight operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
-  int out_step = static_cast<int>(transformed_output_grad.dims()[1]) / groups;
+  int64_t in_step = transformed_input.dims()[1] / groups;
+  int64_t out_step = transformed_output_grad.dims()[1] / groups;
 
   bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
 
@@ -163,7 +163,7 @@ void ConvGradKernel(const Context& dev_ctx,
   phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
   phi::funcs::Col2VolFunctor<Context, T> col2vol;
 
-  for (int i = 0; i < batch_size; i++) {
+  for (int64_t i = 0; i < batch_size; i++) {
     DenseTensor out_grad_batch =
         transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape);
     DenseTensor in_grad_batch =
@@ -327,7 +327,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
   UpdatePaddingAndDilation(
       &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
 
-  const int batch_size = static_cast<int>(transformed_X.dims()[0]);
+  const int64_t batch_size = transformed_X.dims()[0];
   std::vector<int64_t> filter_shape_vec(common::vectorize(W.dims()));
   std::vector<int64_t> output_shape_vec(
       common::vectorize(transformed_dY.dims()));
@@ -354,8 +354,8 @@ void ConvGradGradKernel(const Context& dev_ctx,
       transformed_dY.dims()[1],
       transformed_dY.numel() /
          (transformed_dY.dims()[0] * transformed_dY.dims()[1])};
-  int in_step = static_cast<int>(transformed_X.dims()[1]) / groups;
-  int out_step = static_cast<int>(transformed_dY.dims()[1]) / groups;
+  int64_t in_step = transformed_X.dims()[1] / groups;
+  int64_t out_step = transformed_dY.dims()[1] / groups;
 
   bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
   DenseTensor col;
@@ -394,7 +394,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
   phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
   phi::funcs::Col2VolFunctor<Context, T> col2vol;
 
-  for (int i = 0; i < batch_size; i++) {
+  for (int64_t i = 0; i < batch_size; i++) {
     DenseTensor dy_batch =
         transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
     DenseTensor dx_batch = transformed_dX.Slice(i, i + 1).Resize(input_shape);

paddle/phi/kernels/impl/conv_kernel_impl.h

Lines changed: 4 additions & 4 deletions
@@ -76,7 +76,7 @@ void ConvKernelImpl(const Context& dev_ctx,
   UpdatePaddingAndDilation(
       &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
 
-  const int batch_size = static_cast<int>(transformed_input.dims()[0]);
+  const int64_t batch_size = transformed_input.dims()[0];
 
   // filter_shape_vec:
   // {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
@@ -137,14 +137,14 @@ void ConvKernelImpl(const Context& dev_ctx,
       (transformed_output.dims()[0] * transformed_output.dims()[1])};
 
   // convolution operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
-  int out_step = static_cast<int>(transformed_output.dims()[1]) / groups;
+  int64_t in_step = transformed_input.dims()[1] / groups;
+  int64_t out_step = transformed_output.dims()[1] / groups;
 
   phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
   phi::funcs::Vol2ColFunctor<Context, T> vol2col;
 
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
-  for (int i = 0; i < batch_size; i++) {
+  for (int64_t i = 0; i < batch_size; i++) {
     DenseTensor in_batch =
         transformed_input.Slice(i, i + 1).Resize(in_matrix_shape);
     DenseTensor out_batch =
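
In the conv, fold, and frame kernels the hazard is the mirror image: dims() already returns 64-bit extents, and the old static_cast<int> silently truncated them before they fed the batch loops and step sizes. A hedged sketch of that truncation, using a made-up extent rather than a real DDim:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical extent larger than INT32_MAX, standing in for dims()[0].
  const int64_t extent = 3000000000LL;
  const int truncated = static_cast<int>(extent);  // old pattern: value mangled
  const int64_t kept = extent;                     // new pattern: exact
  std::cout << truncated << " vs " << kept << "\n";
  return 0;
}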

paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h

Lines changed: 8 additions & 8 deletions
@@ -259,7 +259,7 @@ void ComputeDDoutWithoutBroadcast(const CPUContext& dev_ctx UNUSED,
   auto* y_data = y.data<T>();
   auto* out_data = out.data<T>();
   auto* ddout_data = ddout->data<T>();
-  for (int i = 0; i < out_numel; i++) {
+  for (int64_t i = 0; i < out_numel; i++) {
     ddout_data[i] = dout_op(ddx_data[i], ddy_data[i], y_data[i], out_data[i]);
   }
 }
@@ -283,7 +283,7 @@ void ComputeDDoutWithBroadcast(const CPUContext& dev_ctx UNUSED,
   auto* out_data = out.data<T>();
   auto* ddout_data = ddout->data<T>();
   std::vector<int> index_array(max_dim, 0);
-  for (int i = 0; i < out_numel; i++) {
+  for (int64_t i = 0; i < out_numel; i++) {
     int x_index = phi::funcs::GetElementwiseIndex(
         x_dims_array, max_dim, index_array.data());
     int y_index = phi::funcs::GetElementwiseIndex(
@@ -381,9 +381,9 @@ __global__ void ComputeDDoutWithoutBroadcastGPUKernel(const T* ddx_data,
     const T* y_data,
     const T* out_data,
     T* ddout_data,
-    int numel,
+    int64_t numel,
     DDout_OP dout_op) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (tid >= numel) return;
   ddout_data[tid] =
       dout_op(ddx_data[tid], ddy_data[tid], y_data[tid], out_data[tid]);
@@ -418,16 +418,16 @@ __global__ void ComputeDDoutWithBroadcastGPUKernel(
     const T* y_data,
     const T* out_data,
     T* ddout_data,
-    int numel,
+    int64_t numel,
     const CudaIntArray x_dims_array,
     const CudaIntArray y_dims_array,
     const CudaIntArray out_dims_array,
     const int max_dim,
     DDout_OP dout_op) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (tid >= numel) return;
-  int x_index = 0, y_index = 0, x_index_prod = 1, y_index_prod = 1,
-      out_index = tid, dim_index;
+  int64_t x_index = 0, y_index = 0, x_index_prod = 1, y_index_prod = 1,
+          out_index = tid, dim_index;
   for (int64_t i = max_dim - 1; i >= 0; i--) {
     if (out_index == 0) break;
     dim_index = out_index % out_dims_array[i];
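
The broadcast kernel above also widens its per-element bookkeeping (x_index, y_index, the stride products, and out_index), since all of them are derived from a flat output index that can itself exceed INT32_MAX. A simplified CPU analogue of that flat-index decomposition, not Paddle's actual GetElementwiseIndex helper:

#include <cstdint>
#include <vector>

// Map a flat 64-bit output index to the flat index of a (possibly broadcast)
// input; dimensions of size 1 in the input contribute nothing to the index.
int64_t BroadcastInputIndex(int64_t out_index,
                            const std::vector<int64_t>& out_dims,
                            const std::vector<int64_t>& in_dims) {
  int64_t in_index = 0;
  int64_t in_stride = 1;
  for (int i = static_cast<int>(out_dims.size()) - 1; i >= 0; --i) {
    const int64_t dim_index = out_index % out_dims[i];
    out_index /= out_dims[i];
    if (in_dims[i] != 1) {
      in_index += dim_index * in_stride;
    }
    in_stride *= in_dims[i];
  }
  return in_index;
}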

paddle/phi/kernels/impl/fold_grad_kernel_impl.h

Lines changed: 4 additions & 4 deletions
@@ -38,7 +38,7 @@ void FoldGradKernel(const Context& dev_ctx,
   if (!x_grad) return;
 
   const auto& x_dims = x_grad->dims();
-  const int batch_size = static_cast<int>(x_dims[0]);
+  const int64_t batch_size = x_dims[0];
 
   int output_height = (output_sizes[0] + 2 * paddings[0] -
                        (dilations[0] * (kernel_sizes[0] - 1) + 1)) /
@@ -49,8 +49,8 @@ void FoldGradKernel(const Context& dev_ctx,
                       strides[1] +
                   1;
 
-  int n_input_plane = x_dims[1];
-  int n_output_plane = n_input_plane / (kernel_sizes[0] * kernel_sizes[1]);
+  int64_t n_input_plane = x_dims[1];
+  int64_t n_output_plane = n_input_plane / (kernel_sizes[0] * kernel_sizes[1]);
 
   DDim out_shape =
       common::make_ddim({n_output_plane, output_sizes[0], output_sizes[1]});
@@ -59,7 +59,7 @@ void FoldGradKernel(const Context& dev_ctx,
 
   phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
 
-  for (int i = 0; i < batch_size; i++) {
+  for (int64_t i = 0; i < batch_size; i++) {
     DenseTensor out_grad_batch = out_grad.Slice(i, i + 1).Resize(out_shape);
     DenseTensor x_grad_batch =
         x_grad->Slice(i, i + 1).Resize(input_matrix_shape);

paddle/phi/kernels/impl/fold_kernel_impl.h

Lines changed: 4 additions & 4 deletions
@@ -33,7 +33,7 @@ void FoldKernel(const Context& dev_ctx,
     const std::vector<int>& paddings,
     const std::vector<int>& dilations,
     DenseTensor* out) {
-  const int batch_size = static_cast<int>(x.dims()[0]);
+  const int64_t batch_size = x.dims()[0];
   dev_ctx.template Alloc<T>(out);
 
   phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
@@ -48,8 +48,8 @@ void FoldKernel(const Context& dev_ctx,
                       strides[1] +
                   1;
 
-  int n_input_plane = x_dims[1];
-  int n_output_plane = n_input_plane / (kernel_sizes[0] * kernel_sizes[1]);
+  int64_t n_input_plane = x_dims[1];
+  int64_t n_output_plane = n_input_plane / (kernel_sizes[0] * kernel_sizes[1]);
 
   DDim output_shape =
       common::make_ddim({n_output_plane, output_sizes[0], output_sizes[1]});
@@ -60,7 +60,7 @@ void FoldKernel(const Context& dev_ctx,
   phi::funcs::SetConstant<Context, T> set_zero;
   set_zero(dev_ctx, out, static_cast<T>(0));
 
-  for (int i = 0; i < batch_size; i++) {
+  for (int64_t i = 0; i < batch_size; i++) {
     DenseTensor out_batch =
         out->Slice(i, i + 1).Resize(output_shape);  // im size=3

paddle/phi/kernels/impl/frame_grad_kernel_impl.h

Lines changed: 3 additions & 2 deletions
@@ -29,9 +29,10 @@ void FrameGradKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(dx);
   const size_t dout_rank = dout.dims().size();
   const size_t dx_rank = dx->dims().size();
-  const int n_frames =
+  const int64_t n_frames =
       (axis == 0) ? dout.dims()[0] : dout.dims()[dout_rank - 1];
-  const int seq_length = (axis == 0) ? dx->dims()[0] : dx->dims()[dx_rank - 1];
+  const int64_t seq_length =
+      (axis == 0) ? dx->dims()[0] : dx->dims()[dx_rank - 1];
   DenseTensor dout_tmp = dout;
 
   DDim preserved_dims;

paddle/phi/kernels/impl/frame_kernel_impl.h

Lines changed: 3 additions & 2 deletions
@@ -28,8 +28,9 @@ void FrameKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   const size_t x_rank = x.dims().size();
   const size_t out_rank = out->dims().size();
-  const int n_frames = (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
-  const int seq_length = (axis == 0) ? x.dims()[0] : x.dims()[x_rank - 1];
+  const int64_t n_frames =
+      (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
+  const int64_t seq_length = (axis == 0) ? x.dims()[0] : x.dims()[x_rank - 1];
   // When the number of input dims is larger than 2, it needs to copy
   // from x to resize input into 2d and output into 3d. Moreover, output
   // dims will be restored at the last step.

paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h

Lines changed: 15 additions & 1 deletion
@@ -30,7 +30,8 @@ void GumbelSoftmaxGradKernel(const Context& dev_ctx,
     DenseTensor* dx) {
   const int rank = dx->dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
-  int axis_dim = dx->dims()[axis];
+  int64_t axis_dim = dx->dims()[axis];
+
   // allocate memory on device.
 
   dev_ctx.template Alloc<T>(dx);
@@ -44,6 +45,19 @@ void GumbelSoftmaxGradKernel(const Context& dev_ctx,
     return;
   }
 
+  // TODO(large-tensor): Softmax functor implementation still uses int for
+  // dimensions. Need to update Softmax functor to support dimensions >
+  // INT32_MAX.
+  PADDLE_ENFORCE_LE(
+      axis_dim,
+      std::numeric_limits<int>::max(),
+      common::errors::InvalidArgument(
+          "The axis dimension (%ld) exceeds the maximum value that int can "
+          "represent (%d). GumbelSoftmax gradient operation does not support "
+          "such large tensors yet.",
+          axis_dim,
+          std::numeric_limits<int>::max()));
+
   const int size_to_axis = funcs::SizeToAxis(axis, dx->dims());
   const int size_from_axis = funcs::SizeFromAxis(axis, dx->dims());
   DenseTensor dx_2d(*dx), out_2d(out), dout_2d(dout);

paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h

Lines changed: 15 additions & 1 deletion
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <iostream>
 #include <random>
 
 #include "paddle/phi/core/dense_tensor.h"
@@ -52,7 +53,7 @@ void GumbelSoftmaxKernelHelper(const Context& dev_ctx,
     DenseTensor* out) {
   const int rank = x.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
-  int axis_dim = x.dims()[axis];
+  int64_t axis_dim = x.dims()[axis];
 
   PADDLE_ENFORCE_GT(temperature,
                     0,
@@ -73,6 +74,19 @@ void GumbelSoftmaxKernelHelper(const Context& dev_ctx,
     return;
   }
 
+  // TODO(large-tensor): Softmax functor implementation still uses int for
+  // dimensions. Need to update Softmax functor to support dimensions >
+  // INT32_MAX.
+  PADDLE_ENFORCE_LE(
+      axis_dim,
+      std::numeric_limits<int>::max(),
+      common::errors::InvalidArgument(
+          "The axis dimension (%ld) exceeds the maximum value that int can "
+          "represent (%d). GumbelSoftmax operation does not support such "
+          "large tensors yet.",
+          axis_dim,
+          std::numeric_limits<int>::max()));
+
   const int size_to_axis = funcs::SizeToAxis(axis, x.dims());
   const int size_from_axis = funcs::SizeFromAxis(axis, x.dims());
   DenseTensor x_noise_2d, out_2d(*out);
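
The two GumbelSoftmax files take a different route: axis_dim is widened, but the downstream Softmax functor still takes int dimensions, so both kernels guard it with the explicit range check shown in the PADDLE_ENFORCE_LE blocks above. A standalone sketch of the same guard without the Paddle macro, purely illustrative:

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Reject axis dimensions that the int-based softmax path cannot represent.
inline void CheckAxisDimFitsInt(int64_t axis_dim) {
  if (axis_dim > std::numeric_limits<int>::max()) {
    throw std::invalid_argument(
        "axis dimension " + std::to_string(axis_dim) +
        " exceeds INT32_MAX; the int-based softmax path does not support it");
  }
}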
