@@ -85,7 +85,12 @@ def triton_fp8_per_group_rowwise_scales(
85
85
n_groups = offsets .numel ()
86
86
87
87
# allocate on-device buffers for output and scales
88
- output_buffer = torch .empty ((m , k ), dtype = output_dtype , device = hp_tensor .device )
88
+ output_buffer = torch .empty (
89
+ (m , k ), dtype = output_dtype , device = hp_tensor .device
90
+ ).as_strided (
91
+ (m , k ), # shape
92
+ (1 , m ), # stride
93
+ )
89
94
scales_buffer = torch .empty (
90
95
(m * n_groups ), dtype = torch .float32 , device = hp_tensor .device
91
96
)
@@ -114,7 +119,7 @@ def triton_fp8_per_group_rowwise_scales(
114
119
round_scales_to_power_of_2 ,
115
120
EPS = EPS ,
116
121
)
117
- return output_buffer , scales_buffer
122
+ return output_buffer . transpose ( - 2 , - 1 ). contiguous (). transpose ( - 2 , - 1 ) , scales_buffer
118
123
119
124
120
125
@triton_fp8_per_group_rowwise_scales .register_fake
@@ -336,8 +341,8 @@ def _triton_fp8_per_group_colwise_scales_kernel(
336
341
offsets_ptr ,
337
342
out_ptr ,
338
343
scales_ptr ,
344
+ M : int ,
339
345
K : int ,
340
- N : int ,
341
346
stride_input_row : int ,
342
347
stride_input_col : int ,
343
348
stride_output_row : int ,
@@ -372,7 +377,7 @@ def _triton_fp8_per_group_colwise_scales_kernel(
372
377
+ block_col_offs [None , :] * stride_input_col
373
378
)
374
379
block_mask = (block_row_offs [:, None ] < group_row_end_idx ) & (
375
- block_col_offs [None , :] < N
380
+ block_col_offs [None , :] < K
376
381
)
377
382
data = tl .load (input_ptr + block_offs , mask = block_mask , other = 0.0 ).to (
378
383
input_dtype
@@ -394,8 +399,8 @@ def _triton_fp8_per_group_colwise_scales_kernel(
394
399
# store colwise scales for each group in contiguous memory:
395
400
# [group0_col0, group0_col1, ..., group2_col0, group2_col1]
396
401
# note: input tensor is in col-major memory layout.
397
- scales_offs = block_col_offs + (N * offset_idx )
398
- scales_mask = tl .arange (0 , BLOCK_SIZE ) < N
402
+ scales_offs = block_col_offs + (K * offset_idx )
403
+ scales_mask = tl .arange (0 , BLOCK_SIZE ) < K
399
404
tl .store (scales_ptr + scales_offs , scales , mask = scales_mask )
400
405
401
406
# perform float8 conversion for this group
@@ -406,7 +411,7 @@ def _triton_fp8_per_group_colwise_scales_kernel(
406
411
+ block_col_offs [None , :] * stride_input_col
407
412
)
408
413
block_mask = (block_row_offs [:, None ] < group_row_end_idx ) & (
409
- block_col_offs [None , :] < N
414
+ block_col_offs [None , :] < K
410
415
)
411
416
data = tl .load (input_ptr + block_offs , mask = block_mask , other = 0.0 ).to (
412
417
input_dtype
0 commit comments