diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 1859d8a5cb..8359238289 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -94,7 +94,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP #pragma unroll for (int i2 = 0; i2 < nvec; ++i2) { const int row = tile_row + i1 * nvec + i2; - size_t row_offset = static_cast(row) * row_length; const int col = tile_col + j1 * nvec; Vec local_input; Vec local_output; @@ -102,7 +101,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - local_input.data.elt[j2] = input[row_offset + col + j2]; + local_input.data.elt[j2] = input[static_cast(row) * row_length + col + j2]; } } } @@ -113,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[row_offset + col + j2] = local_output.data.elt[j2]; + output[static_cast(row) * row_length + col + j2] = local_output.data.elt[j2]; } } } else if (row < padded_num_rows) { // padding for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[row_offset + col + j2] = local_zero; + output[static_cast(row) * row_length + col + j2] = local_zero; } } } @@ -179,7 +178,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult #pragma unroll for (int i2 = 0; i2 < nvec; ++i2) { const int row = tile_row + i1 * nvec + i2; - size_t row_offset = static_cast(row) * row_length; const int col = tile_col + j1 * nvec; Vec local_input; Vec local_output; @@ -187,7 +185,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - local_input.data.elt[j2] = input[row_offset + col + j2]; + local_input.data.elt[j2] = input[static_cast(row) * row_length + col + j2]; } } } @@ -198,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[row_offset + col + j2] = local_output.data.elt[j2]; + output[static_cast(row) * row_length + col + j2] = local_output.data.elt[j2]; } } }