From 271c509d592a426c4717d258b98d990b5f46c89f Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Fri, 31 Jan 2025 10:20:01 -0800 Subject: [PATCH] DP4AMatMul perf refinements (#23539) In this change 1. Vectorization of k is updated to 4. 2. Tile_A, Tile_B are stored transposed in shared memory. This makes it so that memory locality is improved for our access pattern. 3. Lane output is switched to being individual vectors and its loop unrolled, this solves the problem where laneoutput was not on registers before. Perf improvements are not very consistent with this change. On Tigerlake GPU with 32.0.101.6460 (latest intel drivers) ``` Baseline model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 7.36557e+06 <<<< avg (tokens/s): 135.903 p50 (us): 7.35498e+06 stddev (us): 27599 n: 5 * 1001 token(s) With Change model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 6.52302e+06 <<<< avg (tokens/s): 153.457 p50 (us): 6.52224e+06 stddev (us): 10407.3 n: 5 * 1001 token(s) ``` However, using the Intel GPA comparing before and after profile, one can clearly see straight runs of ALU work without being interspersed by writebacks to local memory that contained lane_output before. ![image](https://github.com/user-attachments/assets/e01d3474-8406-4a61-b352-2ecbf0855a7f) --- .../webgpu/quantization/matmul_nbits.cc | 170 +++++++++++------- 1 file changed, 107 insertions(+), 63 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 90e6516ff45d1..c79efee65e5c5 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -613,17 +613,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const tile_size_k = 32; const vec_factor = 4; const u32_factor = 4; - const tile_size_k_vec = 4; + const tile_size_k_vec = 2; const block_size = 32; // Shared memory - var tile_A : array, tile_size_k_vec>, tile_size>; // 64 x 32 - var scale_A : array; // 64 x 1 - var tile_B : array, tile_size_k_vec>, tile_size>; // 64 x 32 - var scale_B : array; // 64 x 1 - - // Private memory - var lane_output: array; + var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_A : array; // 64 x 1 + var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_B : array; // 64 x 1 fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) { @@ -632,11 +629,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { { return; } - tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col]; + tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; if (col == 0) { - // kidx_v - covers 8 values of k - scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16]; + // kidx_v - covers 16 values of k + scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; } } @@ -648,36 +645,45 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return; } - let b_value = input_b[b_global*uniforms.K8+kidx_v+col]; - var b_value_lower = vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8); - var b_value_upper = vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[row][col][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[row][col][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); + var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); if (col == 0) { - // kidx_v - each kidx_v covers 8 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4]; + // kidx_v - each kidx_v covers 16 values of k + scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2]; } } - fn DP4AI(a:vec4, b:vec4) -> i32 + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t { - var local_sum = dot4I8Packed(a[0], b[0]); - local_sum += dot4I8Packed(a[1], b[1]); - local_sum += dot4I8Packed(a[2], b[2]); - local_sum += dot4I8Packed(a[3], b[3]); - return local_sum; + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(local_sum) * scale; } - )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( // During the load phase we use all 256 threads to load 64 rows of A/B. - // For each row we load 4 vectorized elements, which are 32 elements of K. + // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. let a_global_base = workgroup_id.x * tile_size; let b_global_base = workgroup_id.y * tile_size; - let load_row = u32(local_idx/4); - let load_col = u32(local_idx%4); + let load_AorB = u32(local_idx/128); + let load_row = u32((local_idx%128)/2); + let load_col = u32(local_idx%2); // During the compute phase, we have the 64x64 tile split into // subtiles of 16x16. We have a grid of 4x4 subtiles. @@ -689,42 +695,81 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { // For each subtile we have 16 threads assigned. let a_idx = u32(local_idx % subtile_size); - // K's vectrorization is 8 items per index. See input_a/input_b. - // tile_size_k_vec - is the k tile size in vectorized k units/space (1/8). - for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec) + var lane_output1: vec4; + var lane_output2: vec4; + var lane_output3: vec4; + var lane_output4: vec4; + // K's vectrorization is 16 items per index. See input_a/input_b. + // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is + // k tile size is 32. In vectorized space that is 32/16 = 2. + for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) { - // Populate shared memory for the workgroup - loadSHMA(a_global_base, kidx_v, load_row, load_col); - loadSHMB(b_global_base, kidx_v, load_row, load_col); + // Load Phase: Populate shared memory for the workgroup. + if (load_AorB == 0) + { + loadSHMA(a_global_base, kidx_v, load_row, load_col); + } + else + { + loadSHMB(b_global_base, kidx_v, load_row, load_col); + } workgroupBarrier(); - var own_a0: vec4 = vec4(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]); - var own_a1: vec4 = vec4(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]); - var own_scale_a = scale_A[base_A + a_idx]; + // Compute phase: Perform matmul for this subtile 16 x 32 x 16. + // Step 1: Load from shared memory into registers across entire subgroup. + var own_a0: vec4 = tile_A[0][base_A + a_idx]; + var own_a1: vec4 = tile_A[1][base_A + a_idx]; + var own_scale_a: output_element_t = scale_A[base_A + a_idx]; if (sg_size == 16) { - var own_b0: vec4 = vec4(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]); - var own_b1: vec4 = vec4(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]); - var own_scale_b = scale_B[base_B + sg_id]; - for (var col:u32 = 0; col < 16; col++) - { - var local_scale_b = subgroupShuffle(own_scale_b, col); - local_scale_b = local_scale_b * own_scale_a; - var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col)); - local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col)); - lane_output[col] += (output_element_t(local_sum) * local_scale_b); - } + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); + + lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); } else { - for (var col:u32 = 0; col < 16; col++) - { - var b0: vec4 = vec4(tile_B[base_B + col][0], tile_B[base_B + col][1]); - var b1: vec4 = vec4(tile_B[base_B + col][2], tile_B[base_B + col][3]); - var local_sum = DP4AI(own_a0, b0); - local_sum += DP4AI(own_a1, b1); - lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]); - } + // Code for other subgroup sizes, simply doesnt use subgroups at all. + // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. + lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); } workgroupBarrier(); } @@ -735,11 +780,10 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { // This creates a shader requirement that uniforms.N % 16 == 0 if (a_global < uniforms.M && b_global < uniforms.N) { - for (var i:u32 = 0; i < 4; i++) - { - let lidx = i * 4; - output[output_idx+i] = vec4(lane_output[lidx], lane_output[lidx+1] , lane_output[lidx+2], lane_output[lidx+3]); - } + output[output_idx] = lane_output1; + output[output_idx+1] = lane_output2; + output[output_idx+2] = lane_output3; + output[output_idx+3] = lane_output4; } )MAIN_FN"; @@ -812,9 +856,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context mul_program.SetDispatchGroupSize( (M + kTileSize - 1) / kTileSize, (N + kTileSize - 1) / kTileSize, 1); - mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components)}, + mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kU32Components)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)},