@@ -98,13 +98,12 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
98
98
const void * const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
99
99
100
100
// load scaling factors for this lane's initial four 1x128 tiles
101
- const uint32_t lane_load_idx = (lane % 4 ) * 8 + (lane / 4 );
102
101
uint4 sf;
103
102
if constexpr (no_oob) {
104
- sf = reinterpret_cast <const uint4 *>(warp_src)[lane_load_idx ];
103
+ sf = reinterpret_cast <const uint4 *>(warp_src)[lane ];
105
104
} else {
106
- if ((out_tile_y < tiles_y - 1 ) || lane_load_idx < first_oob) {
107
- sf = reinterpret_cast <const uint4 *>(warp_src)[lane_load_idx ];
105
+ if ((out_tile_y < tiles_y - 1 ) || lane < first_oob) {
106
+ sf = reinterpret_cast <const uint4 *>(warp_src)[lane ];
108
107
} else {
109
108
sf = uint4 {0 , 0 , 0 , 0 };
110
109
}
@@ -113,8 +112,12 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
113
112
// pack the exponent bits of the scaling factors
114
113
uint32_t packed_exponents = (sf.x >> 23 ) | (sf.y >> 15 ) | (sf.z >> 7 ) | (sf.w << 1 );
115
114
116
- // transpose 4x4 matrices of scaling factors
115
+ // partially swizzle the scaling factors: lane i takes the packed value held by lane (i % 4) * 8 + (i / 4)
117
116
constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF ; // no divergent branches
117
+ const uint32_t lane_load_idx = (lane % 4 ) * 8 + (lane / 4 );
118
+ packed_exponents = __shfl_sync (ACTIVE_MASK, packed_exponents, lane_load_idx);
119
+
120
+ // transpose 4x4 matrices of scaling factors
118
121
packed_exponents = transpose_4x4_byte_matrix (packed_exponents, lane % 4 , ACTIVE_MASK);
119
122
120
123
// broadcast the scaling factors for sixteen 1x32 tiles
0 commit comments