[Issue]: gfx950 slow compilation for one of Kimi-K2.5-MXFP4 module #3566

@hanhanW

Problem Description

CK2stages A4W4 dynamic MoE cold JIT: unused small-inter_dim kernels and >50s required-instance compile time

Hi, this is Hanhan from Anush's team, working on Kimi-K2.5-MXFP4. We recently looked at bringing up the Kimi-K2.5-MXFP4 model for a demo and hit slow compilation for some modules. Attached is a summary that I put together with my agent. I've done my best to isolate the module and analyze it further. I think there are two main issues:

  1. A single GEMM kernel can take ~50 seconds to compile.
  2. Unused kernels also get compiled, which increases the compilation time from ~50 seconds to ~110 seconds.

Below are the summary and repro generated by agents and me, and I'm happy to provide more context. Could you take a look at how we can improve this? Thanks in advance!

Summary

This issue has two related but separable problems in the A4W4 FP4 preshuffle CK2stages dynamic MoE module:

  1. The default generated module includes three gemm2_64x* stage2 kernels for small inter_dim cases. For the Kimi-like path we are investigating, the observed stage2 shape is inter_dim=1024, block_m=128, so those small-inter_dim stage2 kernels are not used by that path.
  2. Even after locally gating out those unused gemm2_64x* stage2 kernels, the remaining required CK instance translation units still take about 55 seconds on the cold-build critical path. The slowest required instance is gemm1_256x128..., which alone took 54.274s in a fresh self-run.

So there are two possible fixes/workstreams:

  • Split or gate unused small-inter_dim stage2 kernels so Kimi-like large-inter_dim modules do not compile them.
  • Reduce, prebuild, or cache the remaining required 256x128 CK template instances, because pruning unused kernels does not remove the remaining ~55s lower bound.

Environment Used for Fresh Self-Run

Repository checkout:

9bab8388c35936814a659b4ebd245c491e1b940a

Build environment:

ROCm: 7.2.1
GPU_ARCHS=gfx950
MAX_JOBS=16
Container: rocm/atom:rocm7.2.1-ubuntu24.04-pytorch2.9.1-atom0.1.2

The self-run did not require a GPU. It only performed HIP/CK compilation.

Self-contained Reproduction

The following steps should be enough to reproduce the two stock symptoms from an AITER checkout in a ROCm/PyTorch environment with hipcc and ninja available.

No GPU is required for the compile-only repro.

1. Generator-only repro for the unused stage2 kernels

This does not run the long compile. It only shows that the stock generator includes the small-inter_dim gemm2_64x* stage2 kernels in the same A4W4 preshuffle module.

mkdir -p /tmp/ck2_issue_default_blob
GPU_ARCHS=gfx950 python3 csrc/ck_gemm_moe_2stages_codegen/gen_instances.py \
  -a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 \
  --preshuffle -w /tmp/ck2_issue_default_blob

find /tmp/ck2_issue_default_blob/instances -type f -name 'moe_ck2stages_gemm2_64x*.cu' \
  -printf '%f %s bytes\n' | sort

Expected output:

moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 550 bytes
moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 551 bytes

2. Stock cold-build repro for compile-time breakdown

This compiles the stock A4W4 preshuffle CK2stages module and prints the slowest .ninja_log edges. It imports aiter/jit/core.py directly to avoid unrelated package-level imports.

cat >/tmp/repro_ck2_stock_compile.py <<'PY'
import importlib.util
import json
import pathlib
import sys
import time

root = pathlib.Path.cwd()
core_path = root / "aiter" / "jit" / "core.py"
spec = importlib.util.spec_from_file_location("aiter_jit_core_direct", core_path)
core = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = core
spec.loader.exec_module(core)

md = "module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on_b16_silu_per_1x32_mulWeightStage2_"
d = core.get_args_of_build("module_moe_ck2stages")
d["md_name"] = md
d["blob_gen_cmd"] = [
    f"{core.AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py "
    "-a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 --preshuffle -w {}"
]

started = time.perf_counter()
core.build_module(
    d["md_name"], d["srcs"], d["flags_extra_cc"], d["flags_extra_hip"],
    d["blob_gen_cmd"], d["extra_include"], d["extra_ldflags"], False,
    d["is_python_module"], d["is_standalone"], d["torch_exclude"],
    d.get("third_party", []), d.get("hipify", False),
)
elapsed = time.perf_counter() - started

log = pathlib.Path(core.get_user_jit_dir()) / "build" / md / "build" / ".ninja_log"
rows = []
for line in log.read_text().splitlines():
    if line.startswith("#") or not line:
        continue
    p = line.split("\t")
    rows.append({
        "duration_s": (int(p[1]) - int(p[0])) / 1000.0,
        "start_s": int(p[0]) / 1000.0,
        "end_s": int(p[1]) / 1000.0,
        "out": p[3],
    })

print(json.dumps({
    "build_elapsed_s": elapsed,
    "ninja_critical_edge_s": max(r["end_s"] for r in rows),
    "entries": len(rows),
    "top": sorted(rows, key=lambda r: r["duration_s"], reverse=True)[:12],
}, indent=2))
PY

export PYTHONPATH=$PWD
export AITER_META_DIR=$PWD
export AITER_JIT_DIR=/tmp/aiter_ck2_stock_jit
export GPU_ARCHS=gfx950
export MAX_JOBS=16
export AITER_REBUILD=1
rm -rf "$AITER_JIT_DIR"
python3 /tmp/repro_ck2_stock_compile.py

This stock repro should show both:

  • the unused small-inter_dim gemm2_64x* stage2 objects being compiled; and
  • slow CK instance compilation, including both unused small-inter_dim objects and required 256x128 objects.

The later --a4w4-stage2-gt256-only numbers are an optional isolation experiment from a local prototype. They are not required to reproduce the stock issue, but they show that after removing the unused gemm2_64x* stage2 objects, a required gemm1_256x128 instance still dominates the cold-build wall time.

Issue 1: The Default A4W4 Preshuffle Module Generates Unused Small-inter_dim Stage2 Kernels

Minimal generator reproducer:

mkdir -p /tmp/ck2_issue_default_blob
GPU_ARCHS=gfx950 python3 csrc/ck_gemm_moe_2stages_codegen/gen_instances.py \
  -a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 \
  --preshuffle -w /tmp/ck2_issue_default_blob

find /tmp/ck2_issue_default_blob/instances -type f -name '*.cu' \
  -printf '%f %s bytes\n' | sort

Fresh generator output from this checkout included ten instance TUs:

moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 542 bytes
moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 539 bytes
moe_ck2stages_gemm2_256x128x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 553 bytes
moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 550 bytes
moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 551 bytes

The last three gemm2_64x* stage2 instances are the small-inter_dim stage2 variants.
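
As a rough illustration, the instance file names can be split into a stage number and a tile tuple to pick out these variants programmatically. The regex and both helper names below are assumptions made for this sketch and are not part of AITER; the rule "stage 2 with a leading tile field of 64" is simply read off the file list above.

```python
import re

# Hypothetical helpers (not AITER code). They parse names such as
# moe_ck2stages_gemm2_64x32x32x128_1x1_... into (stage, tile) fields.
_NAME_RE = re.compile(
    r"^moe_ck2stages_gemm(?P<stage>[12])_(?P<tile>\d+(?:x\d+){3})_"
)


def parse_instance(name):
    """Return (stage, (t0, t1, t2, t3)) for a generated instance file name."""
    m = _NAME_RE.match(name)
    if m is None:
        raise ValueError(f"unrecognized instance name: {name}")
    tile = tuple(int(v) for v in m.group("tile").split("x"))
    return int(m.group("stage")), tile


def is_small_inter_dim_stage2(name):
    """True for the gemm2_64x* files listed above (assumed naming rule)."""
    stage, tile = parse_instance(name)
    return stage == 2 and tile[0] == 64
```

Applied to the ten file names above, this should flag exactly the three gemm2_64x* files and leave gemm1_64x32x32x128 alone, since that one is a stage1 instance.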

For the Kimi-like path under investigation, the observed stage2 dispatch shape is:

inter_dim=1024
block_m=128

That path dispatches to the large-inter_dim gemm2_256x128x128x128 branch, not the gemm2_64x* stage2 branches.
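
The selection for this path can be sketched in Python. This mirrors only the inter_dim > 256 branch of the prototype dispatch shown in the appendix; the function name is hypothetical, and the inter_dim <= 256 branch is deliberately left unmodeled because the stock dispatch for it is not reproduced in this issue.

```python
def stage2_branch_for_large_inter_dim(inter_dim, block_m):
    """Sketch of the inter_dim > 256 stage2 selection (hypothetical helper).

    Returns the instance-name prefix of the branch that would be chosen.
    """
    if inter_dim <= 256:
        # Not modeled here: the small-inter_dim branch is outside this sketch.
        raise NotImplementedError("inter_dim <= 256 is not modeled in this sketch")
    if block_m not in (32, 64, 128):
        raise ValueError(f"unsupported block_m for this sketch: {block_m}")
    return f"gemm2_256x{block_m}x128x128"


# The Kimi-like shape reported above.
print(stage2_branch_for_large_inter_dim(inter_dim=1024, block_m=128))
```

For inter_dim=1024 and block_m=128 this returns gemm2_256x128x128x128, which is the branch named above.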

Issue 2: After Removing the Unused Stage2 Kernels, the Remaining Required Instances Still Compile Slowly

To isolate the lower bound after removing the unused small-inter_dim stage2 kernels, I used a local prototype generator option:

--a4w4-stage2-gt256-only

That prototype generated seven instance TUs:

moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 542 bytes
moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 539 bytes
moe_ck2stages_gemm2_256x128x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 553 bytes
moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes

Fresh self-run result:

build_elapsed_s:       59.80018826806918
ninja_critical_edge_s: 54.877
ninja entries:         10

Top edges from the fresh .ninja_log:

54.274s moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
48.473s moe_ck2stages_gemm2_256x128x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cuda.o
39.992s moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
37.370s moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cuda.o
33.152s moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
32.575s moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
32.367s moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cuda.o
19.702s gemm_moe_ck2stages.cuda.o
19.016s moe_ck_2stages_pybind.cuda.o
 0.603s module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on_b16_silu_per_1x32_mulWeightStage2_stage2Gt256Only_selfrun.so

This shows that the remaining ~55s is not caused by compiling seven kernels sequentially. Ninja runs the instance TUs in parallel. The wall-time critical path is dominated by a single required CK instance:

moe_ck2stages_gemm1_256x128x128x128...cuda.o

That one object took 54.274s in this self-run. The link step was only 0.603s.
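
A quick arithmetic check of that claim, using only the numbers from the listing above. The short dictionary keys are abbreviations chosen for this sketch, and the model (all compile edges start together, the link starts when the slowest object finishes) is an assumption that happens to match the reported ninja_critical_edge_s.

```python
# Per-object compile times in seconds, copied from the .ninja_log listing above.
compile_s = {
    "gemm1_256x128": 54.274,
    "gemm2_256x128": 48.473,
    "gemm1_256x64": 39.992,
    "gemm2_256x64": 37.370,
    "gemm1_256x32": 33.152,
    "gemm1_64x32": 32.575,
    "gemm2_256x32": 32.367,
    "gemm_moe_ck2stages": 19.702,
    "moe_ck_2stages_pybind": 19.016,
}
link_s = 0.603

# If every edge ran one after another, the build would take the sum.
serial_s = round(sum(compile_s.values()) + link_s, 3)

# With all compile edges running in parallel, the slowest object plus the
# link step bounds the wall time from below.
parallel_s = round(max(compile_s.values()) + link_s, 3)
slowest = max(compile_s, key=compile_s.get)

print(serial_s, parallel_s, slowest)
```

The parallel estimate comes out to 54.877s, equal to the reported ninja_critical_edge_s, while the serial sum would be 317.524s. So the build is already fully parallel, and only a faster gemm1_256x128 compile (or avoiding it) can shorten the critical path.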

Why the Tiny Generated .cu Files Still Compile Slowly

Each generated instance .cu is only about 540 bytes, but it includes the heavy CK MoE MXFP4 preshuffle implementation header and explicitly instantiates a large template.

The generated instance body has this shape:

#include "gemm_moe_ck2stages_common_mxfp4.cuh"

using A0DataType = FP4X2;
using B0DataType = FP4X2;
using AccDataType = F32;
using EDataType = B16;
using CDEElementOp = MulABScaleShuffled;
const bool Nswizzle = false;
const bool PerTensorQuant = 3 == static_cast<int>(QuantType::per_Tensor);
const bool MulRoutedWeight = false;
const int ActOP = 1;
CK_MOE_STAGE1_GEMM_DEFINE(256, 128, 128, 128, 1, 4, V3)

CK_MOE_STAGE1_GEMM_DEFINE and CK_MOE_STAGE2_GEMM_DEFINE explicitly instantiate ck_moe_stage{1,2}_gemm<...>, which in turn instantiates ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<...> with many layout, transfer, XDL, shuffle, quantization, and elementwise template parameters.

So the source file is small, but the template instantiation and generated LLVM IR are large.

Expected Behavior

For a Kimi-like inter_dim > 256 path:

  1. The dynamic module should not eagerly compile the unused gemm2_64x* small-inter_dim stage2 variants.
  2. The remaining required module should avoid a 50-60s cold-start compile if possible, either through prebuilt artifacts, cache seeding, finer module specialization, or compiler/CK changes that reduce the required 256x128 instance compile cost.

Proposed Fix Directions

For issue 1:

  • Add an official generator/module split for A4W4 preshuffle CK2stages large-inter_dim modules.
  • Keep the module/cache key distinct so full and reduced modules cannot collide.
  • For inter_dim <= 256, either dispatch to a full module or fail fast with a clear error if a reduced module is selected.

For issue 2:

  • Treat the remaining gemm1_256x128 and gemm2_256x128 instances as the current cold-build critical path.
  • Consider prebuilding or packaging these required instances for known production shapes.
  • Consider shape-aware module generation if runtime only needs block_m=128.
  • Consider CK/compiler work to reduce template instantiation, IR generation, or optimization cost for DeviceMoeGemmMXBPreShuffle.

Validation Criteria

A fix for issue 1 is sufficient if:

find <generated_blob>/instances -name 'moe_ck2stages_gemm2_64x*.cu'

returns no files for the Kimi-like large-inter_dim module, while the inter_dim > 256 dispatch still uses the same gemm2_256x32/64/128 branches.

A fix for issue 2 is sufficient if:

moe_ck2stages_gemm1_256x128x128x128...cuda.o
moe_ck2stages_gemm2_256x128x128x128...cuda.o

no longer dominate the cold-build critical path, or the module is available through a prebuilt/cache-seeded path so startup does not compile those templates.
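
The issue 1 criterion could be automated along these lines. This is a minimal sketch: the function name is hypothetical, and it only assumes the <generated_blob>/instances layout shown in the find commands above.

```python
from pathlib import Path


def small_inter_dim_stage2_sources(blob_dir):
    """List generated gemm2_64x* instance sources under <blob_dir>/instances.

    Hypothetical helper; it searches recursively, like the find command above.
    """
    instances = Path(blob_dir) / "instances"
    return sorted(p.name for p in instances.rglob("moe_ck2stages_gemm2_64x*.cu"))


def check_issue1_fixed(blob_dir):
    """Raise AssertionError if any small-inter_dim stage2 source was generated."""
    leftover = small_inter_dim_stage2_sources(blob_dir)
    assert not leftover, f"unexpected small-inter_dim stage2 sources: {leftover}"
```

Running check_issue1_fixed against the blob directory produced by a reduced module would then pass only when no gemm2_64x* sources were generated. It does not verify the second half of the criterion, that the inter_dim > 256 dispatch is unchanged.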

Open Questions

  • Should AITER expose the large-inter_dim module split as an env-gated workaround, an explicit API option, or automatic shape-specialized module generation?
  • For known deployment shapes, should AITER provide prebuilt CK2stages dynamic modules or a documented cache-seeding recipe?
  • Can CK reduce compile-time cost for the required MXFP4 preshuffle DeviceMoeGemmMXBPreShuffle 256x128 instances without changing runtime performance?

Appendix: Optional Prototype Patch for the gt256-only Isolation Experiment

The stock reproduction above does not require this patch. It reproduces the unused-kernel generation and the slow cold compile from an unmodified checkout.

This optional patch is only for reproducing the isolation experiment where the three gemm2_64x* small-inter_dim stage2 kernels are removed and the remaining required compile time is measured.

After applying the patch, the generator supports:

--a4w4-stage2-gt256-only

The direct-build script from the self-contained repro can then be rerun with:

md = "module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on_b16_silu_per_1x32_mulWeightStage2_stage2Gt256Only"
d["blob_gen_cmd"] = [
    f"{core.AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py "
    "-a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 "
    "--preshuffle --a4w4-stage2-gt256-only -w {}"
]

Prototype patch:

diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py
index 4ea3a5c2e..ea2a0d68a 100755
--- a/aiter/ops/moe_op.py
+++ b/aiter/ops/moe_op.py
@@ -4,6 +4,7 @@
 import torch
 from torch import Tensor
 from typing import Optional
+import os
 from ..jit.core import compile_ops, AITER_CSRC_DIR
 from .enum import ActivationType, Enum, QuantType
 from ..utility import dtypes
@@ -509,6 +510,16 @@ def get_moe_stage_module(
     quant_type = (
         QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type
     )
+    stage2_gt256_only = (
+        int(os.environ.get("AITER_MOE_CK2STAGES_A4W4_STAGE2_GT256_ONLY", "0")) != 0
+        and input_dtype == dtypes.fp4x2
+        and weight_dtype == dtypes.fp4x2
+        and preshuffle_mode
+        and quant_type == QuantType.per_1x32
+        and mul_routed_weight_stage == 2
+        and not is_splitk
+    )
+    stage2_gt256_str = "--a4w4-stage2-gt256-only" if stage2_gt256_only else ""
     act = str(activation).split(".")[-1].lower()
     quant_type = str(quant_type).split(".")[-1].lower()
 
@@ -524,9 +535,11 @@ def get_moe_stage_module(
     ]
     if is_splitk:
         parts.append("splitk")
+    if stage2_gt256_only:
+        parts.append("stage2Gt256Only")
     md_name = "_".join(parts)
     blob_gen_cmd = [
-        f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} {splitk_str} -w {{}}"
+        f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} {splitk_str} {stage2_gt256_str} -w {{}}"
     ]
 
     return md_name, blob_gen_cmd
diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
index f035a7079..e5e469620 100644
--- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
+++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
@@ -444,6 +444,7 @@ def get_gemm2_kernels_list(
     QuantType: str,
     MulRoutedWeight: bool,
     preshuffle: bool = False,
+    a4w4_stage2_gt256_only: bool = False,
 ) -> list:
     arch = get_gfx()
 
@@ -478,6 +479,9 @@ def get_gemm2_kernels_list(
     else:
         raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}")
     kernels_list = {k: copy.deepcopy(v) for k, v in gemm2_kernels_dict[tag].items()}
+    if tag == "a4w4" and a4w4_stage2_gt256_only:
+        for kernel_id in (4, 5, 6):
+            kernels_list.pop(kernel_id, None)
     for id, kernel in kernels_list.items():
         kernel.MulRoutedWeight = MulRoutedWeight
         kernel.Nswizzle = Nswizzle
diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
index b778c74b1..77beb0837 100644
--- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
+++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
@@ -673,6 +673,47 @@ A4W4_gemm2_heuristic_dispatch = """
 #endif
 """
 
+A4W4_gemm2_gt256only_heuristic_dispatch = """
+#if defined(__Float4_e2m1fn_x2)
+    if (dtype_checker<{A0DataType}>{{}}(x_dtype)
+        && dtype_checker<{B0DataType}>{{}}(w_dtype)
+        && dtype_checker<{EDataType}>{{}}(y_dtype)
+        && {MulRoutedWeight} == mul_routed_weight_stage
+        && {Quant} == quant
+        && {Preshuffle} == is_shuffled)
+    {{
+        if (inter_dim <= 256)
+        {{
+            TORCH_CHECK(
+                false,
+                "A4W4 stage2 gt256-only CK2stages module does not include inter_dim <= 256 kernels");
+        }}
+        else
+        {{
+            if (block_m == 32)
+            {{
+                return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 32, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
+            }}
+            else if (block_m == 64)
+            {{
+                return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 64, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
+            }}
+            else if (block_m == 128)
+            {{
+                return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 128, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
+            }}
+            else
+            {{
+                TORCH_CHECK(
+                    false,
+                    "Unsupported block_m value for moe heuristic dispatch: ",
+                    block_m);
+            }}
+        }}
+    }}
+#endif
+"""
+
 A4W4_bns_gemm2_heuristic_dispatch = """
 #if defined(__Float4_e2m1fn_x2)
     if (dtype_checker<{A0DataType}>{{}}(x_dtype)
@@ -854,6 +895,7 @@ class ck_moe_2stage_gemm_codegen:
         mul_routed_weight_stage,
         preshuffle,
         splitk,
+        a4w4_stage2_gt256_only=False,
     ):
         self.working_path = working_path
         self.a_dtype = a_dtype.upper()
@@ -865,6 +907,7 @@ class ck_moe_2stage_gemm_codegen:
         self.nswizzle = False
         self.preshuffle = preshuffle
         self.splitk = splitk
+        self.a4w4_stage2_gt256_only = a4w4_stage2_gt256_only
 
     def generate_instance_and_lookUpTable(self):
         _, gemm1_kernel_list = get_gemm1_kernels_list(
@@ -886,6 +929,7 @@ class ck_moe_2stage_gemm_codegen:
             self.quant_type,
             self.mul_routed_weight_stage == 2,
             self.preshuffle,
+            self.a4w4_stage2_gt256_only,
         )
         kernel_list = list(gemm1_kernel_list.values()) + list(
             gemm2_kernel_list.values()
@@ -979,6 +1023,8 @@ class ck_moe_2stage_gemm_codegen:
         gemm1_heuristic_dispatch, gemm2_heuristic_dispatch = heuristic_dispatch_dict[
             tag
         ]
+        if tag == "a4w4" and self.a4w4_stage2_gt256_only:
+            gemm2_heuristic_dispatch = A4W4_gemm2_gt256only_heuristic_dispatch
         with open(f_gemm1_heuristic_dispatch, "a") as f_h:
             gemm1_fp32 = self.splitk and (quanttype == "_blockscale")
             gemm1_heuristic_dispatch_str = gemm1_heuristic_dispatch.format(
@@ -1106,6 +1152,12 @@ if __name__ == "__main__":
         help="enable moe_stage1 splitk mode",
     )
 
+    parser.add_argument(
+        "--a4w4-stage2-gt256-only",
+        action="store_true",
+        help="for a4w4 preshuffle, omit stage2 inter_dim <= 256 kernels",
+    )
+
     args = parser.parse_args()
     args.quant_type = (
         "per_1x128" if args.quant_type == "per_128x128" else args.quant_type
@@ -1157,6 +1209,7 @@ if __name__ == "__main__":
                 routed_weight,
                 preshuffle_mode,
                 False,  # splitk
+                False,
             )
             codegen.generate_instance_and_lookUpTable()
 
@@ -1176,6 +1229,7 @@ if __name__ == "__main__":
                 routed_weight,
                 preshuffle_mode,
                 splitk,
+                False,
             )
             codegen.generate_instance_and_lookUpTable()
 
@@ -1201,6 +1255,7 @@ if __name__ == "__main__":
                 routed_weight,
                 preshuffle_mode,
                 False,  # splitk
+                False,
             )
             codegen.generate_instance_and_lookUpTable()
     else:
@@ -1216,6 +1271,7 @@ if __name__ == "__main__":
                 args.mul_routed_weight_stage,
                 args.preshuffle,
                 args.issplitk,
+                args.a4w4_stage2_gt256_only,
             )
             codegen.generate_instance_and_lookUpTable()

Operating System

Ubuntu 24.04.4 LTS (Noble Numbat)

CPU

AMD EPYC 9575F 64-Core Processor

GPU

AMD Instinct MI350X

ROCm Version

ROCm 7.2.1

