aiter/fused_moe_bf16_asm.py: 21 additions & 1 deletion
@@ -77,11 +77,29 @@ def asm_moe(
     dtype = hidden_states.dtype
     device = topk_ids.device
+    lastdim_mul = 8 if w1.dtype in {dtypes.i32, torch.uint32} else 1
+    is_g1u1 = (
+        w2.shape[2] * 2 * lastdim_mul == w1.shape[1] and fc2_smooth_scale is not None
+    )
+    enable_fp32 = (
+        fc2_smooth_scale is not None
+        and is_g1u1
+        and (inter_dim % 384 == 0 or inter_dim % 320 == 0)
+        and w1.dtype == dtypes.i8
+        and a16
+    )
+    moebuf_dtype = torch.float32 if enable_fp32 else dtype
     sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = (
         moe_sorting_ck(
-            topk_ids, topk_weight, global_E, model_dim, dtype, BLOCK_SIZE_M, expert_mask
+            topk_ids,
+            topk_weight,
+            global_E,
+            model_dim,
+            moebuf_dtype,
+            BLOCK_SIZE_M,
+            expert_mask,
         )
     )
 
     if fc1_scale is None:
         # pure bf16
         aiter.fmoe(
@@ -262,6 +280,8 @@ def asm_moe(
         )
 
         # fc2_smooth_scale)
+    if enable_fp32 and dtype != torch.float32:
+        moe_buf = moe_buf.to(dtype)
     return moe_buf
 
 
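Net effect of the Python-side change: for the int8 g1u1 a16 path whose inter_dim matches a shipped tile, the output buffer is allocated in fp32 so the weighted expert contributions are summed at full precision, then downcast once at the end. A minimal standalone sketch of the gate and downcast, assuming torch.int8 stands in for aiter's dtypes.i8 and with placeholder inputs:

import torch

# Sketch of the enable_fp32 gate added above; all inputs are illustrative.
def pick_moebuf_dtype(out_dtype, w1_dtype, is_g1u1, inter_dim, a16, fc2_smooth_scale):
    enable_fp32 = (
        fc2_smooth_scale is not None
        and is_g1u1
        and (inter_dim % 384 == 0 or inter_dim % 320 == 0)
        and w1_dtype == torch.int8
        and a16
    )
    return (torch.float32 if enable_fp32 else out_dtype), enable_fp32

out_dtype = torch.bfloat16
moebuf_dtype, enable_fp32 = pick_moebuf_dtype(
    out_dtype, torch.int8, is_g1u1=True, inter_dim=768, a16=True,
    fc2_smooth_scale=torch.ones(1),
)
moe_buf = torch.zeros(4, 16, dtype=moebuf_dtype)  # kernel accumulates here in fp32
if enable_fp32 and out_dtype != torch.float32:
    moe_buf = moe_buf.to(out_dtype)  # single downcast at the end, as in the diff
assert moe_buf.dtype == out_dtype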
csrc/py_itfs_cu/asm_fmoe.cu: 32 additions & 14 deletions
@@ -771,6 +771,8 @@ void fmoe_g1u1_a16(torch::Tensor& out, // [token_cnt, dim]
         config_map = &cfg_fmoe_bf16_pertokenInt8_g1u1_silu;
     else if(out.dtype() == at::ScalarType::BFloat16 && activation == ActivationType::Gelu)
         config_map = &cfg_fmoe_bf16_pertokenInt8_g1u1_gelu;
+    else if(out.dtype() == at::ScalarType::Float && activation == ActivationType::Silu)
+        config_map = &cfg_fmoe_fp32_pertokenInt8_g1u1_silu;
     else
Comment on lines +774 to 776 (Copilot AI, Jan 28, 2026):

FP32 output is now selectable here, but the only shipped FP32 hsaco in this PR is for gfx942 and subGU_n=384 (see hsa/gfx942/fmoe/silu/fmoe_fp32_pertokenInt8_g1u1_silu.csv). If callers pass out as FP32 on other arches or with an inter_dim not divisible by 384, this will fail later with a generic “No suitable kernel found”. Consider adding an explicit early TORCH_CHECK documenting the FP32 constraints (arch + supported inter_dim tiles) so the failure mode is clearer.
         TORCH_CHECK(
             false, __func__, "Unsupported output dtype or activation type for fmoe_g1u1_a16");
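A guard along the lines Copilot suggests would fail fast before dispatch. A hypothetical Python-level version (the PR adds no such check; the gcnArchName query, and treating gfx942 plus the 320/384 tiles as the full support matrix, are assumptions based on the shipped CSV):

import torch

def check_fp32_fmoe_supported(inter_dim: int, device: torch.device) -> None:
    # Hypothetical early check mirroring the constraints of the shipped fp32 hsaco.
    arch = torch.cuda.get_device_properties(device).gcnArchName
    if not arch.startswith("gfx942"):
        raise RuntimeError(f"fp32 fmoe kernels are only shipped for gfx942, got {arch}")
    if inter_dim % 384 != 0 and inter_dim % 320 != 0:
        raise RuntimeError(
            f"fp32 fmoe needs inter_dim divisible by 384 or 320, got {inter_dim}")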
@@ -793,20 +795,36 @@
         TORCH_CHECK(false, __func__, "Unsupported gate dtype for fmoe_g1u1_a16");
 
     impl_ptr = get_heuristic_kernel(inter_dim, sorted_expert_ids.size(0), config_map, 1);
-    impl_ptr->launch_kernel<uint8_t, uint16_t, true>(out,
-                                                     input,
-                                                     gate,
-                                                     down,
-                                                     sorted_token_ids,
-                                                     sorted_weights,
-                                                     sorted_expert_ids,
-                                                     num_valid_ids,
-                                                     topk,
-                                                     // quant args
-                                                     fc1_smooth_scale,
-                                                     fc1_scale,
-                                                     fc2_scale,
-                                                     fc2_smooth_scale);
+    if(out.dtype() == at::ScalarType::Float)
+        impl_ptr->launch_kernel<uint8_t, float, true>(out,
+                                                      input,
+                                                      gate,
+                                                      down,
+                                                      sorted_token_ids,
+                                                      sorted_weights,
+                                                      sorted_expert_ids,
+                                                      num_valid_ids,
+                                                      topk,
+                                                      // quant args
+                                                      fc1_smooth_scale,
+                                                      fc1_scale,
+                                                      fc2_scale,
+                                                      fc2_smooth_scale);
+    else
+        impl_ptr->launch_kernel<uint8_t, uint16_t, true>(out,
+                                                         input,
+                                                         gate,
+                                                         down,
+                                                         sorted_token_ids,
+                                                         sorted_weights,
+                                                         sorted_expert_ids,
+                                                         num_valid_ids,
+                                                         topk,
+                                                         // quant args
+                                                         fc1_smooth_scale,
+                                                         fc1_scale,
+                                                         fc2_scale,
+                                                         fc2_smooth_scale);
 }
 
 void fmoe_fp8_blockscale_g1u1(torch::Tensor& out, // [token_cnt, dim]
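The dispatch itself stays simple: one heuristic lookup, then the output dtype selects the template instantiation. Sketched in Python for orientation (the returned strings are just labels for the two instantiations above):

import torch

def fmoe_g1u1_a16_launch(out: torch.Tensor) -> str:
    # Activations stay uint8_t (per-token int8) in both branches; only the
    # output element type differs: float for fp32, uint16_t carrying bf16 bits.
    if out.dtype == torch.float32:
        return "launch_kernel<uint8_t, float, true>"
    return "launch_kernel<uint8_t, uint16_t, true>"

assert "float," in fmoe_g1u1_a16_launch(torch.empty(1, dtype=torch.float32))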
hsa/gfx942/fmoe/silu/fmoe_fp32_pertokenInt8_g1u1_silu.csv: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+knl_name,co_name,atm,vskip,smf,tg_num_perCU,ps,subGU_m,subGU_n
+_ZN5aiter50fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x320E,fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x320.co,1,1,1,1,0,32,320
+_ZN5aiter50fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x384E,fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x384.co,1,1,1,1,0,32,384
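Each CSV row maps a mangled kernel symbol to its code object and 32xN tile. A rough Python illustration of how such rows can be filtered against a problem's inter_dim (the real selection happens in get_heuristic_kernel on the C++ side; divisibility as the sole criterion is a simplification):

import csv, io

CSV_TEXT = """knl_name,co_name,atm,vskip,smf,tg_num_perCU,ps,subGU_m,subGU_n
_ZN5aiter50fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x320E,fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x320.co,1,1,1,1,0,32,320
_ZN5aiter50fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x384E,fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x384.co,1,1,1,1,0,32,384
"""

def candidate_kernels(inter_dim: int):
    # Keep kernels whose subGU_n tile evenly divides the intermediate dimension.
    rows = csv.DictReader(io.StringIO(CSV_TEXT))
    return [r["co_name"] for r in rows if inter_dim % int(r["subGU_n"]) == 0]

print(candidate_kernels(1280))  # only the 32x320 tile divides 1280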
fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x320.co: Binary file not shown.
fmoe_fp32_pertokenInt8_g1u1_vs_smf_silu_1tg_32x384.co: Binary file not shown.