Commit ac93c84

[moe training] fix uncoalesced global accesses in per group rowwise scaling kernel
1 parent c3b5a00 · commit ac93c84

File tree

1 file changed: 12 additions & 7 deletions


torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 12 additions & 7 deletions
@@ -85,7 +85,12 @@ def triton_fp8_per_group_rowwise_scales(
     n_groups = offsets.numel()
 
     # allocate on-device buffers for output and scales
-    output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device)
+    output_buffer = torch.empty(
+        (m, k), dtype=output_dtype, device=hp_tensor.device
+    ).as_strided(
+        (m, k),  # shape
+        (1, m),  # stride
+    )
     scales_buffer = torch.empty(
         (m * n_groups), dtype=torch.float32, device=hp_tensor.device
     )
@@ -114,7 +119,7 @@ def triton_fp8_per_group_rowwise_scales(
         round_scales_to_power_of_2,
         EPS=EPS,
     )
-    return output_buffer, scales_buffer
+    return output_buffer.transpose(-2, -1).contiguous().transpose(-2, -1), scales_buffer
 
 
 @triton_fp8_per_group_rowwise_scales.register_fake
@@ -336,8 +341,8 @@ def _triton_fp8_per_group_colwise_scales_kernel(
     offsets_ptr,
     out_ptr,
     scales_ptr,
+    M: int,
     K: int,
-    N: int,
     stride_input_row: int,
     stride_input_col: int,
     stride_output_row: int,
@@ -372,7 +377,7 @@ def _triton_fp8_per_group_colwise_scales_kernel(
         + block_col_offs[None, :] * stride_input_col
     )
     block_mask = (block_row_offs[:, None] < group_row_end_idx) & (
-        block_col_offs[None, :] < N
+        block_col_offs[None, :] < K
     )
     data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
         input_dtype
@@ -394,8 +399,8 @@ def _triton_fp8_per_group_colwise_scales_kernel(
     # store colwise scales for each group in contiguous memory:
     # [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
     # note: input tensor is in col-major memory layout.
-    scales_offs = block_col_offs + (N * offset_idx)
-    scales_mask = tl.arange(0, BLOCK_SIZE) < N
+    scales_offs = block_col_offs + (K * offset_idx)
+    scales_mask = tl.arange(0, BLOCK_SIZE) < K
     tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)
 
     # perform float8 conversion for this group
@@ -406,7 +411,7 @@ def _triton_fp8_per_group_colwise_scales_kernel(
         + block_col_offs[None, :] * stride_input_col
     )
     block_mask = (block_row_offs[:, None] < group_row_end_idx) & (
-        block_col_offs[None, :] < N
+        block_col_offs[None, :] < K
     )
     data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
         input_dtype
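
For readers unfamiliar with the allocation pattern in the first hunk, the standalone Python sketch below (not part of the commit; the sizes and dtype are illustrative assumptions) shows what as_strided((m, k), (1, m)) does to the memory layout, and why the transpose(-2, -1).contiguous().transpose(-2, -1) round trip in the return statement preserves that layout.

    import torch

    m, k = 8, 4

    # Row-major allocation: elements within a row are adjacent in memory.
    row_major = torch.empty((m, k), dtype=torch.bfloat16)
    assert row_major.stride() == (k, 1)

    # Column-major allocation via as_strided: elements within a *column* are
    # adjacent, so stepping through rows moves through consecutive addresses.
    col_major = torch.empty((m, k), dtype=torch.bfloat16).as_strided(
        (m, k),  # shape
        (1, m),  # stride
    )
    assert col_major.stride() == (1, m)
    assert not col_major.is_contiguous()  # not contiguous in the row-major sense

    # The transpose/contiguous/transpose round trip keeps the logical (m, k)
    # shape while leaving the storage column-major: the transposed (k, m) view
    # is already contiguous, so .contiguous() is a no-op here.
    out = col_major.transpose(-2, -1).contiguous().transpose(-2, -1)
    assert out.shape == (m, k)
    assert out.stride() == (1, m)

The intent suggested by the commit title is that, with the output stored column-major, stores to neighboring rows of the logical (m, k) output land on consecutive global-memory addresses, i.e. they coalesce instead of being strided by k.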
