Problem Description
CK2stages A4W4 dynamic MoE cold JIT: unused small-inter_dim kernels and >50s required-instance compile time
Hi, this is Hanhan, working on Kimi-K2.5-MXFP4 on Anush's team. We recently looked at bringing up the Kimi-K2.5-MXFP4 model for a demo and hit slow compilation for some modules. Attached is the summary that I worked on with my agent. I've done my best to isolate the module and analyze it further. I think there are two main issues:
- A single GEMM kernel can take ~50 seconds to compile.
- Unused kernels also get compiled, which increases the compilation time from ~50 seconds to ~110 seconds.
Below is the summary and repro generated by my agents and me, and I'm happy to provide more context. Could you take a look at how we can improve this? Thanks in advance!
Summary
This issue has two related but separable problems in the A4W4 FP4 preshuffle CK2stages dynamic MoE module:
- The default generated module includes three gemm2_64x* stage2 kernels for small inter_dim cases. For the Kimi-like path we are investigating, the observed stage2 shape is inter_dim=1024, block_m=128, so those small-inter_dim stage2 kernels are not used by that path.
- Even after locally gating out those unused gemm2_64x* stage2 kernels, the remaining required CK instance translation units still take about 55 seconds on the cold-build critical path. The slowest required instance is gemm1_256x128..., which alone took 54.274s in a fresh self-run.
So there are two possible fixes/workstreams:
- Split or gate unused small-inter_dim stage2 kernels so Kimi-like large-inter_dim modules do not compile them.
- Reduce, prebuild, or cache the remaining required 256x128 CK template instances, because pruning unused kernels does not remove the remaining ~55s lower bound.
Environment Used for Fresh Self-Run
Repository checkout:
9bab8388c35936814a659b4ebd245c491e1b940a
Build environment:
ROCm: 7.2.1
GPU_ARCHS=gfx950
MAX_JOBS=16
Container: rocm/atom:rocm7.2.1-ubuntu24.04-pytorch2.9.1-atom0.1.2
The self-run did not require a GPU. It only performed HIP/CK compilation.
Self-contained Reproduction
The following steps should be enough to reproduce the two stock symptoms from an AITER checkout in a ROCm/PyTorch environment with hipcc and ninja available.
No GPU is required for the compile-only repro.
1. Generator-only repro for the unused stage2 kernels
This does not run the long compile. It only shows that the stock generator includes the small-inter_dim gemm2_64x* stage2 kernels in the same A4W4 preshuffle module.
mkdir -p /tmp/ck2_issue_default_blob
GPU_ARCHS=gfx950 python3 csrc/ck_gemm_moe_2stages_codegen/gen_instances.py \
-a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 \
--preshuffle -w /tmp/ck2_issue_default_blob
find /tmp/ck2_issue_default_blob/instances -type f -name 'moe_ck2stages_gemm2_64x*.cu' \
-printf '%f %s bytes\n' | sort
Expected output:
moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 550 bytes
moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 551 bytes
2. Stock cold-build repro for compile-time breakdown
This compiles the stock A4W4 preshuffle CK2stages module and prints the slowest .ninja_log edges. It uses direct import of aiter/jit/core.py to avoid unrelated package-level imports.
cat >/tmp/repro_ck2_stock_compile.py <<'PY'
import importlib.util
import json
import pathlib
import sys
import time
root = pathlib.Path.cwd()
core_path = root / "aiter" / "jit" / "core.py"
spec = importlib.util.spec_from_file_location("aiter_jit_core_direct", core_path)
core = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = core
spec.loader.exec_module(core)
md = "module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on_b16_silu_per_1x32_mulWeightStage2_"
d = core.get_args_of_build("module_moe_ck2stages")
d["md_name"] = md
d["blob_gen_cmd"] = [
f"{core.AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py "
"-a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 --preshuffle -w {}"
]
started = time.perf_counter()
core.build_module(
d["md_name"], d["srcs"], d["flags_extra_cc"], d["flags_extra_hip"],
d["blob_gen_cmd"], d["extra_include"], d["extra_ldflags"], False,
d["is_python_module"], d["is_standalone"], d["torch_exclude"],
d.get("third_party", []), d.get("hipify", False),
)
elapsed = time.perf_counter() - started
log = pathlib.Path(core.get_user_jit_dir()) / "build" / md / "build" / ".ninja_log"
rows = []
for line in log.read_text().splitlines():
if line.startswith("#") or not line:
continue
p = line.split("\t")
rows.append({
"duration_s": (int(p[1]) - int(p[0])) / 1000.0,
"start_s": int(p[0]) / 1000.0,
"end_s": int(p[1]) / 1000.0,
"out": p[3],
})
print(json.dumps({
"build_elapsed_s": elapsed,
"ninja_critical_edge_s": max(r["end_s"] for r in rows),
"entries": len(rows),
"top": sorted(rows, key=lambda r: r["duration_s"], reverse=True)[:12],
}, indent=2))
PY
export PYTHONPATH=$PWD
export AITER_META_DIR=$PWD
export AITER_JIT_DIR=/tmp/aiter_ck2_stock_jit
export GPU_ARCHS=gfx950
export MAX_JOBS=16
export AITER_REBUILD=1
rm -rf "$AITER_JIT_DIR"
python3 /tmp/repro_ck2_stock_compile.py
This stock repro should show both:
- the unused small-inter_dim gemm2_64x* stage2 objects being compiled; and
- slow CK instance compilation, including both unused small-inter_dim objects and required 256x128 objects.
The later --a4w4-stage2-gt256-only numbers are an optional isolation experiment from a local prototype. They are not required to reproduce the stock issue, but they show that after removing the unused gemm2_64x* stage2 objects, a required gemm1_256x128 instance still dominates the cold-build wall time.
Issue 1: The Default A4W4 Preshuffle Module Generates Unused Small-inter_dim Stage2 Kernels
Minimal generator reproducer:
mkdir -p /tmp/ck2_issue_default_blob
GPU_ARCHS=gfx950 python3 csrc/ck_gemm_moe_2stages_codegen/gen_instances.py \
-a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 \
--preshuffle -w /tmp/ck2_issue_default_blob
find /tmp/ck2_issue_default_blob/instances -type f -name '*.cu' \
-printf '%f %s bytes\n' | sort
Fresh generator output from this checkout included ten instance TUs:
moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 542 bytes
moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 539 bytes
moe_ck2stages_gemm2_256x128x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 553 bytes
moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 550 bytes
moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 551 bytes
The last three gemm2_64x* stage2 instances are the small-inter_dim stage2 variants.
For the Kimi-like path under investigation, the observed stage2 dispatch shape is:
inter_dim=1024
block_m=128
That path dispatches to the large-inter_dim gemm2_256x128x128x128 branch, not the gemm2_64x* stage2 branches.
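The stage2 selection described above can be sketched in Python as follows. This is only an illustrative mirror of the heuristic, not AITER code: the function name `select_stage2_kernel` is hypothetical, the inter_dim <= 256 threshold and the block_m values are taken from the prototype dispatch in the appendix, and the choice of a specific tile inside the gemm2_64x* family is not modeled.

```python
def select_stage2_kernel(inter_dim: int, block_m: int) -> str:
    """Return the stage2 instance family an A4W4 preshuffle call would land on.

    Hypothetical helper for illustration; the real dispatch is generated C++.
    """
    if inter_dim <= 256:
        # Small-inter_dim family. Which gemm2_64x* tile is picked for a given
        # block_m is not modeled in this sketch.
        return "gemm2_64x*"
    if block_m in (32, 64, 128):
        # Large-inter_dim family: 256 threads, block_m rows, 128x128 N/K tile.
        return f"gemm2_256x{block_m}x128x128"
    raise ValueError(f"Unsupported block_m value for moe heuristic dispatch: {block_m}")


# The Kimi-like shape reported in this issue.
print(select_stage2_kernel(inter_dim=1024, block_m=128))
```

For the Kimi-like shape this returns the gemm2_256x128x128x128 family, so none of the three gemm2_64x* translation units are reachable from that call.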
Issue 2: After Removing the Unused Stage2 Kernels, the Remaining Required Instances Still Compile Slowly
To isolate the lower bound after removing the unused small-inter_dim stage2 kernels, I used a local prototype generator option:
--a4w4-stage2-gt256-only
That prototype generated seven instance TUs:
moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 542 bytes
moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 541 bytes
moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cu 539 bytes
moe_ck2stages_gemm2_256x128x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 553 bytes
moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cu 552 bytes
Fresh self-run result:
build_elapsed_s: 59.80018826806918
ninja_critical_edge_s: 54.877
ninja entries: 10
Top edges from the fresh .ninja_log:
54.274s moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
48.473s moe_ck2stages_gemm2_256x128x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cuda.o
39.992s moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
37.370s moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cuda.o
33.152s moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
32.575s moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16.cuda.o
32.367s moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16.cuda.o
19.702s gemm_moe_ck2stages.cuda.o
19.016s moe_ck_2stages_pybind.cuda.o
0.603s module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on_b16_silu_per_1x32_mulWeightStage2_stage2Gt256Only_selfrun.so
This shows that the remaining ~55s is not caused by compiling seven kernels sequentially. Ninja runs the instance TUs in parallel. The wall-time critical path is dominated by a single required CK instance:
moe_ck2stages_gemm1_256x128x128x128...cuda.o
That one object took 54.274s in this self-run. The link step was only 0.603s.
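To make the parallel-versus-serial point concrete, the sketch below reuses the seven instance durations and the ninja_critical_edge_s value from the listing above and compares their serial sum with the observed wall time. The shortened dictionary keys are just labels for this example.

```python
# Durations (seconds) of the seven CK instance objects from the .ninja_log above.
instance_edges = {
    "gemm1_256x128x128x128": 54.274,
    "gemm2_256x128x128x128": 48.473,
    "gemm1_256x64x128x128": 39.992,
    "gemm2_256x64x128x128": 37.370,
    "gemm1_256x32x128x128": 33.152,
    "gemm1_64x32x32x128": 32.575,
    "gemm2_256x32x128x128": 32.367,
}
wall_s = 54.877  # ninja_critical_edge_s from the same self-run

# What a fully sequential build of the instances would cost.
serial_s = sum(instance_edges.values())
# The single slowest edge, which bounds the parallel build from below.
slowest_name, slowest_s = max(instance_edges.items(), key=lambda kv: kv[1])

print(f"serial sum of instance edges: {serial_s:.3f}s")
print(f"slowest edge: {slowest_name} at {slowest_s:.3f}s")
print(f"slowest edge as a share of wall time: {slowest_s / wall_s:.1%}")
```

The serial sum is roughly 278s, about five times the observed wall time, so the build is already well parallelized and the slowest single edge accounts for nearly all of the ~55s.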
Why the Tiny Generated .cu Files Still Compile Slowly
Each generated instance .cu is only about 540 bytes, but it includes the heavy CK MoE MXFP4 preshuffle implementation header and explicitly instantiates a large template.
The generated instance body has this shape:
#include "gemm_moe_ck2stages_common_mxfp4.cuh"
using A0DataType = FP4X2;
using B0DataType = FP4X2;
using AccDataType = F32;
using EDataType = B16;
using CDEElementOp = MulABScaleShuffled;
const bool Nswizzle = false;
const bool PerTensorQuant = 3 == static_cast<int>(QuantType::per_Tensor);
const bool MulRoutedWeight = false;
const int ActOP = 1;
CK_MOE_STAGE1_GEMM_DEFINE(256, 128, 128, 128, 1, 4, V3)
CK_MOE_STAGE1_GEMM_DEFINE and CK_MOE_STAGE2_GEMM_DEFINE explicitly instantiate ck_moe_stage{1,2}_gemm<...>, which in turn instantiates ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<...> with many layout, transfer, XDL, shuffle, quantization, and elementwise template parameters.
So the source file is small, but the template instantiation and generated LLVM IR are large.
Expected Behavior
For a Kimi-like inter_dim > 256 path:
- The dynamic module should not eagerly compile the unused gemm2_64x* small-inter_dim stage2 variants.
- The remaining required module should avoid a 50-60s cold-start compile if possible, either through prebuilt artifacts, cache seeding, finer module specialization, or compiler/CK changes that reduce the required 256x128 instance compile cost.
Proposed Fix Directions
For issue 1:
- Add an official generator/module split for A4W4 preshuffle CK2stages large-inter_dim modules.
- Keep the module/cache key distinct so full and reduced modules cannot collide.
- For inter_dim <= 256, either dispatch to a full module or fail fast with a clear error if a reduced module is selected.
For issue 2:
- Treat the remaining gemm1_256x128 and gemm2_256x128 instances as the current cold-build critical path.
- Consider prebuilding or packaging these required instances for known production shapes.
- Consider shape-aware module generation if runtime only needs block_m=128.
- Consider CK/compiler work to reduce template instantiation, IR generation, or optimization cost for DeviceMoeGemmMXBPreShuffle.
Validation Criteria
A fix for issue 1 is sufficient if:
find <generated_blob>/instances -name 'moe_ck2stages_gemm2_64x*.cu'
returns no files for the Kimi-like large-inter_dim module, while the inter_dim > 256 dispatch still uses the same gemm2_256x32/64/128 branches.
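The same check could be scripted, for example in a CI step. The sketch below is a minimal, hypothetical helper (the name `unused_stage2_instances` is not part of AITER) that assumes the generated blob keeps its instance sources under an instances/ subdirectory, as in the find commands above. The demo runs against a temporary directory with empty placeholder files, not a real generator run.

```python
import pathlib
import tempfile


def unused_stage2_instances(blob_dir):
    """List small-inter_dim stage2 instance sources found in a generated blob."""
    instances = pathlib.Path(blob_dir) / "instances"
    return sorted(p.name for p in instances.glob("moe_ck2stages_gemm2_64x*.cu"))


# Demo on a fake blob directory: one large-inter_dim and one small-inter_dim file.
with tempfile.TemporaryDirectory() as tmp:
    inst = pathlib.Path(tmp) / "instances"
    inst.mkdir()
    (inst / "moe_ck2stages_gemm2_256x128x128x128_demo.cu").touch()
    (inst / "moe_ck2stages_gemm2_64x32x32x128_demo.cu").touch()
    leftovers = unused_stage2_instances(tmp)
    print(leftovers)
```

For a reduced large-inter_dim module the helper should return an empty list; any entry in the result means a gemm2_64x* stage2 kernel is still being generated.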
A fix for issue 2 is sufficient if:
moe_ck2stages_gemm1_256x128x128x128...cuda.o
moe_ck2stages_gemm2_256x128x128x128...cuda.o
no longer dominate the cold-build critical path, or the module is available through a prebuilt/cache-seeded path so startup does not compile those templates.
Open Questions
- Should AITER expose the large-inter_dim module split as an env-gated workaround, an explicit API option, or automatic shape-specialized module generation?
- For known deployment shapes, should AITER provide prebuilt CK2stages dynamic modules or a documented cache-seeding recipe?
- Can CK reduce compile-time cost for the required MXFP4 preshuffle DeviceMoeGemmMXBPreShuffle 256x128 instances without changing runtime performance?
Appendix: Optional Prototype Patch for the gt256-only Isolation Experiment
The stock reproduction above does not require this patch. It reproduces the unused-kernel generation and the slow cold compile from an unmodified checkout.
This optional patch is only for reproducing the isolation experiment where the three gemm2_64x* small-inter_dim stage2 kernels are removed and the remaining required compile time is measured.
After applying the patch, the generator supports:
--a4w4-stage2-gt256-only
The direct-build script from the self-contained repro can then be rerun with:
md = "module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on_b16_silu_per_1x32_mulWeightStage2_stage2Gt256Only"
d["blob_gen_cmd"] = [
f"{core.AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py "
"-a fp4x2 -b fp4x2 -c b16 -q per_1x32 -act silu -m 2 "
"--preshuffle --a4w4-stage2-gt256-only -w {}"
]
Prototype patch
diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py
index 4ea3a5c2e..ea2a0d68a 100755
--- a/aiter/ops/moe_op.py
+++ b/aiter/ops/moe_op.py
@@ -4,6 +4,7 @@
import torch
from torch import Tensor
from typing import Optional
+import os
from ..jit.core import compile_ops, AITER_CSRC_DIR
from .enum import ActivationType, Enum, QuantType
from ..utility import dtypes
@@ -509,6 +510,16 @@ def get_moe_stage_module(
quant_type = (
QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type
)
+ stage2_gt256_only = (
+ int(os.environ.get("AITER_MOE_CK2STAGES_A4W4_STAGE2_GT256_ONLY", "0")) != 0
+ and input_dtype == dtypes.fp4x2
+ and weight_dtype == dtypes.fp4x2
+ and preshuffle_mode
+ and quant_type == QuantType.per_1x32
+ and mul_routed_weight_stage == 2
+ and not is_splitk
+ )
+ stage2_gt256_str = "--a4w4-stage2-gt256-only" if stage2_gt256_only else ""
act = str(activation).split(".")[-1].lower()
quant_type = str(quant_type).split(".")[-1].lower()
@@ -524,9 +535,11 @@ def get_moe_stage_module(
]
if is_splitk:
parts.append("splitk")
+ if stage2_gt256_only:
+ parts.append("stage2Gt256Only")
md_name = "_".join(parts)
blob_gen_cmd = [
- f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} {splitk_str} -w {{}}"
+ f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} {splitk_str} {stage2_gt256_str} -w {{}}"
]
return md_name, blob_gen_cmd
diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
index f035a7079..e5e469620 100644
--- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
+++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
@@ -444,6 +444,7 @@ def get_gemm2_kernels_list(
QuantType: str,
MulRoutedWeight: bool,
preshuffle: bool = False,
+ a4w4_stage2_gt256_only: bool = False,
) -> list:
arch = get_gfx()
@@ -478,6 +479,9 @@ def get_gemm2_kernels_list(
else:
raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}")
kernels_list = {k: copy.deepcopy(v) for k, v in gemm2_kernels_dict[tag].items()}
+ if tag == "a4w4" and a4w4_stage2_gt256_only:
+ for kernel_id in (4, 5, 6):
+ kernels_list.pop(kernel_id, None)
for id, kernel in kernels_list.items():
kernel.MulRoutedWeight = MulRoutedWeight
kernel.Nswizzle = Nswizzle
diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
index b778c74b1..77beb0837 100644
--- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
+++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
@@ -673,6 +673,47 @@ A4W4_gemm2_heuristic_dispatch = """
#endif
"""
+A4W4_gemm2_gt256only_heuristic_dispatch = """
+#if defined(__Float4_e2m1fn_x2)
+ if (dtype_checker<{A0DataType}>{{}}(x_dtype)
+ && dtype_checker<{B0DataType}>{{}}(w_dtype)
+ && dtype_checker<{EDataType}>{{}}(y_dtype)
+ && {MulRoutedWeight} == mul_routed_weight_stage
+ && {Quant} == quant
+ && {Preshuffle} == is_shuffled)
+ {{
+ if (inter_dim <= 256)
+ {{
+ TORCH_CHECK(
+ false,
+ "A4W4 stage2 gt256-only CK2stages module does not include inter_dim <= 256 kernels");
+ }}
+ else
+ {{
+ if (block_m == 32)
+ {{
+ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 32, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
+ }}
+ else if (block_m == 64)
+ {{
+ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 64, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
+ }}
+ else if (block_m == 128)
+ {{
+ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 128, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
+ }}
+ else
+ {{
+ TORCH_CHECK(
+ false,
+ "Unsupported block_m value for moe heuristic dispatch: ",
+ block_m);
+ }}
+ }}
+ }}
+#endif
+"""
+
A4W4_bns_gemm2_heuristic_dispatch = """
#if defined(__Float4_e2m1fn_x2)
if (dtype_checker<{A0DataType}>{{}}(x_dtype)
@@ -854,6 +895,7 @@ class ck_moe_2stage_gemm_codegen:
mul_routed_weight_stage,
preshuffle,
splitk,
+ a4w4_stage2_gt256_only=False,
):
self.working_path = working_path
self.a_dtype = a_dtype.upper()
@@ -865,6 +907,7 @@ class ck_moe_2stage_gemm_codegen:
self.nswizzle = False
self.preshuffle = preshuffle
self.splitk = splitk
+ self.a4w4_stage2_gt256_only = a4w4_stage2_gt256_only
def generate_instance_and_lookUpTable(self):
_, gemm1_kernel_list = get_gemm1_kernels_list(
@@ -886,6 +929,7 @@ class ck_moe_2stage_gemm_codegen:
self.quant_type,
self.mul_routed_weight_stage == 2,
self.preshuffle,
+ self.a4w4_stage2_gt256_only,
)
kernel_list = list(gemm1_kernel_list.values()) + list(
gemm2_kernel_list.values()
@@ -979,6 +1023,8 @@ class ck_moe_2stage_gemm_codegen:
gemm1_heuristic_dispatch, gemm2_heuristic_dispatch = heuristic_dispatch_dict[
tag
]
+ if tag == "a4w4" and self.a4w4_stage2_gt256_only:
+ gemm2_heuristic_dispatch = A4W4_gemm2_gt256only_heuristic_dispatch
with open(f_gemm1_heuristic_dispatch, "a") as f_h:
gemm1_fp32 = self.splitk and (quanttype == "_blockscale")
gemm1_heuristic_dispatch_str = gemm1_heuristic_dispatch.format(
@@ -1106,6 +1152,12 @@ if __name__ == "__main__":
help="enable moe_stage1 splitk mode",
)
+ parser.add_argument(
+ "--a4w4-stage2-gt256-only",
+ action="store_true",
+ help="for a4w4 preshuffle, omit stage2 inter_dim <= 256 kernels",
+ )
+
args = parser.parse_args()
args.quant_type = (
"per_1x128" if args.quant_type == "per_128x128" else args.quant_type
@@ -1157,6 +1209,7 @@ if __name__ == "__main__":
routed_weight,
preshuffle_mode,
False, # splitk
+ False,
)
codegen.generate_instance_and_lookUpTable()
@@ -1176,6 +1229,7 @@ if __name__ == "__main__":
routed_weight,
preshuffle_mode,
splitk,
+ False,
)
codegen.generate_instance_and_lookUpTable()
@@ -1201,6 +1255,7 @@ if __name__ == "__main__":
routed_weight,
preshuffle_mode,
False, # splitk
+ False,
)
codegen.generate_instance_and_lookUpTable()
else:
@@ -1216,6 +1271,7 @@ if __name__ == "__main__":
args.mul_routed_weight_stage,
args.preshuffle,
args.issplitk,
+ args.a4w4_stage2_gt256_only,
)
codegen.generate_instance_and_lookUpTable()
Operating System
Ubuntu 24.04.4 LTS (Noble Numbat)
CPU
AMD EPYC 9575F 64-Core Processor
GPU
AMD Instinct MI350X
ROCm Version
ROCm 7.2.1