
Conversation

@zrr1999 (Member) commented Oct 29, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

Audited the Paddle/paddle/phi/kernels/impl directory for potential large-tensor problems and fixed them. The changes mainly involve:

  • Changing int to int64_t or size_t
  • Adding PADDLE_ENFORCE_LE checks to prevent accidental use where large tensors are not supported
  • Adding comments and TODOs where necessary

1. elementwise_grad_kernel_impl.h (+8, -8)

  • CPU loop index: int i → int64_t i
  • CUDA kernel parameter: int numel → int64_t numel
  • CUDA thread index: int tid → int64_t tid, with the computation fixed to avoid overflow
  • Broadcast index variables: int x_index, y_index, ... → int64_t ...

2. accuracy_check_kernel_impl.h (+11, -11)

  • CUDA kernel parameter: int num → int64_t num
  • Thread index: unsigned int idx → int64_t idx, with the computation fixed
  • Loop variable: int i → int64_t i
  • Modified 3 kernel functions: the generic template and the complex64 and complex128 specializations

3. isclose_kernel_impl.h (+7, -5)

  • Fixed the thread-index computation in 5 CUDA kernels
  • Used static_cast to avoid overflow in the blockIdx.x * blockDim.x multiplication
  • Covers the template version and 4 specializations

4. renorm_impl.h (+11, -7)

  • Grid size computation: int grid → int64_t grid
  • Added a grid-size upper-bound check: std::min(grid, max_grid_dimx)
  • Fixed a kernel argument from numel to dimension_each

5. unstack_kernel_impl.h (+16, -2)

  • Element count: int total_num → int64_t total_num
  • int post → int64_t post
  • Added a large-tensor validation check: StackGradFunctorForRange still uses int indexing, so a PADDLE_ENFORCE_LE ensures the element count does not exceed INT32_MAX

6. kldiv_loss_grad_kernel_impl.h (+2, -2)

  • Element count: int n → int64_t n

7. kldiv_loss_kernel_impl.h (+1, -1)

  • Batch dimension: int batch_size → int64_t batch_size

8. svdvals_grad_kernel_impl.h (+3, -3)

  • Batch count: int batch_count → int64_t batch_count

9. gumbel_softmax_kernel_impl.h (+14, -1)

  • Axis dimension: int axis_dim → int64_t axis_dim
  • Added a large-tensor validation check: the Softmax functor still uses int, so an upper-bound check on the dimension was added

10. gumbel_softmax_grad_kernel_impl.h (+15, -1)

  • Axis dimension: int axis_dim → int64_t axis_dim
  • Added a large-tensor validation check: same check as in the forward pass

11. lrn_kernel_impl.h (+43, -12)

  • Tensor dimensions: int N, C, H, W → int64_t N, C, H, W
  • Added header: #include <algorithm>
  • Added a large-tensor validation check: the GPU kernel still uses int, so all dimensions are checked against INT32_MAX
  • Dimension parameter types in function signatures updated accordingly

12. frame_kernel_impl.h (+3, -2)

  • Frame count: int n_frames → int64_t n_frames
  • Sequence length: int seq_length → int64_t seq_length

13. frame_grad_kernel_impl.h (+3, -2)

  • Frame count: int n_frames → int64_t n_frames
  • Sequence length: int seq_length → int64_t seq_length

14. stft_kernel_impl.h (+2, -2)

  • Frame count: int n_frames → int64_t n_frames
  • Sequence length: int seq_length → int64_t seq_length

15. stft_grad_kernel_impl.h (+2, -2)

  • Frame count: int n_frames → int64_t n_frames
  • Sequence length: int seq_length → int64_t seq_length

16. fold_kernel_impl.h (+4, -4)

  • Batch size: int batch_size → int64_t batch_size
  • Input plane count: int input_planes → int64_t input_planes

17. fold_grad_kernel_impl.h (+4, -4)

  • Batch size: int batch_size → int64_t batch_size
  • Input plane count: int input_planes → int64_t input_planes

18. unfold_kernel_impl.h (+2, -2)

  • Batch size: int batch_size → int64_t batch_size

19. unfold_grad_kernel_impl.h (+2, -2)

  • Batch size: int batch_size → int64_t batch_size

20. lstm_kernel_impl.h (+2, -2)

  • Frame size: int frame_size → int64_t frame_size

21. lstsq_kernel_impl.h (+5, -2)

  • Matrix dimensions: int m, n, nrhs → int64_t m, n, nrhs

22. qr_grad_kernel_impl.h (+2, -2)

  • Matrix dimensions: int m, n → int64_t m, n

23. spectral_norm_grad_kernel_impl.h (+2, -2)

  • Dimension variables: int h, w → int64_t h, w

24. spectral_norm_kernel_impl.h (+4, -4)

  • Height and width: int h, w → int64_t h, w

25. svd_grad_kernel_impl.h (+11, -10)

  • Matrix dimensions: int m, n, k → int64_t m, n, k
  • Batch count: int batch_count → int64_t batch_count

26. conv_kernel_impl.h (+4, -4)

  • Batch size: int batch_size → int64_t batch_size
  • Stride/block size: related computed variables changed to int64_t

27. conv_grad_kernel_impl.h (+8, -8)

  • Batch size: int batch_size → int64_t batch_size
  • Stride/block size: related computed variables changed to int64_t

pcard-93269

@paddle-bot
paddle-bot bot commented Oct 29, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@zrr1999 zrr1999 requested review from Copilot and wanghuancoder and removed request for Copilot October 31, 2025 07:26
@zrr1999 zrr1999 changed the title from "Fix int32 overflow in svd_grad and conv kernel impl" to "Fix int32 overflow issues for large tensor support in paddle/phi/kernels/impl" Oct 31, 2025
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR updates tensor dimension and indexing calculations from int to int64_t to support large tensors that exceed INT32_MAX limits. The changes include adding runtime validation checks where legacy implementations still use int internally, and fixing potential integer overflow issues in CUDA kernel index calculations.

  • Type promotions from int to int64_t for tensor dimensions, batch sizes, and element counts
  • Addition of runtime checks with TODO comments where underlying implementations still use int
  • CUDA kernel index calculation fixes to prevent overflow with explicit casts to int64_t

Reviewed Changes

Copilot reviewed 27 out of 27 changed files in this pull request and generated 6 comments.

File Description
paddle/phi/kernels/impl/unstack_kernel_impl.h Changed total_num and post to int64_t, added validation for StackGradFunctorForRange int limitation
paddle/phi/kernels/impl/unfold_kernel_impl.h Changed batch_size and loop variable to int64_t
paddle/phi/kernels/impl/unfold_grad_kernel_impl.h Changed batch_size and loop variable to int64_t
paddle/phi/kernels/impl/svdvals_grad_kernel_impl.h Changed rows, cols, and batches to int64_t, removed static_cast
paddle/phi/kernels/impl/svd_grad_kernel_impl.h Changed helper function parameters and dimension variables to int64_t
paddle/phi/kernels/impl/stft_kernel_impl.h Changed n_frames and seq_length from int to size_t
paddle/phi/kernels/impl/stft_grad_kernel_impl.h Changed n_frames and seq_length from int to size_t
paddle/phi/kernels/impl/spectral_norm_kernel_impl.h Changed h and w dimension variables to int64_t
paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h Changed h and w dimension variables to int64_t
paddle/phi/kernels/impl/renorm_impl.h Changed grid and grid2 to int64_t, added max grid size checks, fixed kernel parameter
paddle/phi/kernels/impl/qr_grad_kernel_impl.h Changed m and n dimensions to int64_t
paddle/phi/kernels/impl/lstsq_kernel_impl.h Changed m and n to int64_t, added explanatory comment
paddle/phi/kernels/impl/lstm_kernel_impl.h Changed frame_size to int64_t, removed static_cast
paddle/phi/kernels/impl/lrn_kernel_impl.h Changed N, C, H, W to int64_t, added validation and include for std::max, updated functor signature
paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h Changed n to int64_t
paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h Changed numel and expand to int64_t
paddle/phi/kernels/impl/isclose_kernel_impl.h Fixed CUDA index calculations with explicit int64_t/unsigned int casts
paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h Changed axis_dim to int64_t, added validation and iostream include
paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h Changed axis_dim to int64_t, added validation
paddle/phi/kernels/impl/frame_kernel_impl.h Changed n_frames and seq_length to int64_t
paddle/phi/kernels/impl/frame_grad_kernel_impl.h Changed n_frames and seq_length to int64_t
paddle/phi/kernels/impl/fold_kernel_impl.h Changed batch_size, n_input_plane, n_output_plane and loop variable to int64_t
paddle/phi/kernels/impl/fold_grad_kernel_impl.h Changed batch_size, n_input_plane, n_output_plane and loop variable to int64_t
paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h Changed loop variables and kernel parameters to int64_t, fixed CUDA index calculations
paddle/phi/kernels/impl/conv_kernel_impl.h Changed batch_size, in_step, out_step and loop variable to int64_t
paddle/phi/kernels/impl/conv_grad_kernel_impl.h Changed batch_size, in_step, out_step and loop variables to int64_t
paddle/phi/kernels/impl/accuracy_check_kernel_impl.h Changed kernel parameters and variables to int64_t, fixed CUDA index calculations


total_num,
std::numeric_limits<int>::max(),
common::errors::InvalidArgument(
"The total number of elements in UnStack is %d, which exceeds the "
Copilot AI commented Oct 31, 2025

The format specifier %d is used for an int64_t value total_num. This should be %ld or %lld depending on platform to correctly format int64_t values.

Suggested change
"The total number of elements in UnStack is %d, which exceeds the "
"The total number of elements in UnStack is %lld, which exceeds the "


#pragma once

#include <iostream>
Copilot AI commented Oct 31, 2025

The <iostream> header is included but does not appear to be used in this file. This include should be removed unless it's needed for debugging purposes that were accidentally left in.

Suggested change
#include <iostream>

Comment on lines +300 to +301
RenormKernelFunc3<T><<<grid2, block2, 0, stream>>>(
dimension_each, dim_value_data, p, max_norm);
Copilot AI commented Oct 31, 2025

The first parameter was changed from numel to dimension_each. This appears to be a logic change beyond just type conversion and should be verified as correct. If this fixes a bug, it should be documented in the commit message or comments.

std::max({N, C, H, W}),
std::numeric_limits<int>::max(),
common::errors::InvalidArgument(
"One or more tensor dimensions (N=%ld, C=%ld, H=%ld, W=%ld) exceeds "
Copilot AI commented Oct 31, 2025

The format specifier %ld is used for int64_t values. On some platforms (e.g., Windows), int64_t is long long, not long, so %lld should be used instead, or a portable macro such as %" PRId64 " from <cinttypes>.

axis_dim,
std::numeric_limits<int>::max(),
common::errors::InvalidArgument(
"The axis dimension (%ld) exceeds the maximum value that int can "
Copy link

Copilot AI Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The format specifier %ld is used for an int64_t value. This should use %lld or a portable macro such as %" PRId64 " from <cinttypes> for cross-platform compatibility.

@wanghuancoder (Contributor) left a comment

LGTM

@XiaoguangHu01 (Contributor) left a comment

LGTM

@wanghuancoder wanghuancoder merged commit c6bed21 into PaddlePaddle:develop Nov 3, 2025
80 of 82 checks passed
zhengshengning pushed a commit to zhengshengning/Paddle that referenced this pull request Nov 6, 2025
…els/impl (PaddlePaddle#76107)

* Fix int32 overflow in svd_grad and conv kernel impl

* fix
@PaddlePaddle PaddlePaddle deleted a comment from Copilot AI Nov 7, 2025
zhengshengning added a commit that referenced this pull request Nov 7, 2025
…els/impl (#76107) (#76276)

* Fix int32 overflow in svd_grad and conv kernel impl

* fix

Co-authored-by: Zhan Rongrui <[email protected]>

Labels

None yet

Projects

None yet


3 participants