@@ -85,7 +85,7 @@ void ConvGradKernel(const Context& dev_ctx,
8585 UpdatePaddingAndDilation<int >(
8686 &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
8787
88- const int batch_size = static_cast < int >( transformed_input.dims ()[0 ]) ;
88+ const int64_t batch_size = transformed_input.dims ()[0 ];
8989
9090 // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
9191 std::vector<int64_t > filter_shape_vec (common::vectorize (filter.dims ()));
@@ -125,8 +125,8 @@ void ConvGradKernel(const Context& dev_ctx,
125125
126126 // convolution backward input operator: gemm + col2im(or col2vol)
127127 // convolution backward weight operator: im2col(or vol2col) + gemm
128- int in_step = static_cast < int >( transformed_input.dims ()[1 ]) / groups;
129- int out_step = static_cast < int >( transformed_output_grad.dims ()[1 ]) / groups;
128+ int64_t in_step = transformed_input.dims ()[1 ] / groups;
129+ int64_t out_step = transformed_output_grad.dims ()[1 ] / groups;
130130
131131 bool is_expand = IsExpand (filter_shape_vec, strides, paddings, dilations);
132132
@@ -163,7 +163,7 @@ void ConvGradKernel(const Context& dev_ctx,
163163 phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO , Context, T> col2im;
164164 phi::funcs::Col2VolFunctor<Context, T> col2vol;
165165
166- for (int i = 0 ; i < batch_size; i++) {
166+ for (int64_t i = 0 ; i < batch_size; i++) {
167167 DenseTensor out_grad_batch =
168168 transformed_output_grad.Slice (i, i + 1 ).Resize (output_matrix_shape);
169169 DenseTensor in_grad_batch =
@@ -327,7 +327,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
327327 UpdatePaddingAndDilation (
328328 &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
329329
330- const int batch_size = static_cast < int >( transformed_X.dims ()[0 ]) ;
330+ const int64_t batch_size = transformed_X.dims ()[0 ];
331331 std::vector<int64_t > filter_shape_vec (common::vectorize (W.dims ()));
332332 std::vector<int64_t > output_shape_vec (
333333 common::vectorize (transformed_dY.dims ()));
@@ -354,8 +354,8 @@ void ConvGradGradKernel(const Context& dev_ctx,
354354 transformed_dY.dims ()[1 ],
355355 transformed_dY.numel () /
356356 (transformed_dY.dims ()[0 ] * transformed_dY.dims ()[1 ])};
357- int in_step = static_cast < int >( transformed_X.dims ()[1 ]) / groups;
358- int out_step = static_cast < int >( transformed_dY.dims ()[1 ]) / groups;
357+ int64_t in_step = transformed_X.dims ()[1 ] / groups;
358+ int64_t out_step = transformed_dY.dims ()[1 ] / groups;
359359
360360 bool is_expand = IsExpand (filter_shape_vec, strides, paddings, dilations);
361361 DenseTensor col;
@@ -394,7 +394,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
394394 phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO , Context, T> col2im;
395395 phi::funcs::Col2VolFunctor<Context, T> col2vol;
396396
397- for (int i = 0 ; i < batch_size; i++) {
397+ for (int64_t i = 0 ; i < batch_size; i++) {
398398 DenseTensor dy_batch =
399399 transformed_dY.Slice (i, i + 1 ).Resize (output_matrix_shape);
400400 DenseTensor dx_batch = transformed_dX.Slice (i, i + 1 ).Resize (input_shape);
0 commit comments