Skip to content

Commit aeafe79

Browse files
committed
Load scaling factors in unswizzled order in 1d kernel
Signed-off-by: Jan Bielak <[email protected]>
1 parent 2dad396 commit aeafe79

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

transformer_engine/common/swizzle/swizzle_block_scaling.cu

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,12 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
9898
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
9999

100100
// load scaling factors for this lane's initial four 1x128 tiles
101-
const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4);
102101
uint4 sf;
103102
if constexpr (no_oob) {
104-
sf = reinterpret_cast<const uint4*>(warp_src)[lane_load_idx];
103+
sf = reinterpret_cast<const uint4*>(warp_src)[lane];
105104
} else {
106-
if ((out_tile_y < tiles_y - 1) || lane_load_idx < first_oob) {
107-
sf = reinterpret_cast<const uint4*>(warp_src)[lane_load_idx];
105+
if ((out_tile_y < tiles_y - 1) || lane < first_oob) {
106+
sf = reinterpret_cast<const uint4*>(warp_src)[lane];
108107
} else {
109108
sf = uint4{0, 0, 0, 0};
110109
}
@@ -113,8 +112,12 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
113112
// pack the exponent bits of the scaling factors
114113
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
115114

116-
// transpose 4x4 matrices of scaling factors
115+
// partially swizzle the scaling factors
117116
constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches
117+
const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4);
118+
packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx);
119+
120+
// transpose 4x4 matrices of scaling factors
118121
packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK);
119122

120123
// broadcast the scaling factors for sixteen 1x32 tiles

0 commit comments

Comments
 (0)