Skip to content

Commit d0cf397

Browse files
committed
triton softmax modify
1 parent 65e90aa commit d0cf397

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

aiter/ops/triton/softmax.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,38 @@
55

66
_LOGGER = AiterTritonLogger()
77

8+
@triton.jit
def _softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    """Row-wise softmax: one program instance processes one entire row.

    Assumes BLOCK_SIZE is a power of two with BLOCK_SIZE >= n_cols, so a
    single block covers the whole row. Strides are in elements.
    """
    # Rows are independent, so the grid parallelizes over rows.
    pid = tl.program_id(0)
    if pid >= n_rows:
        return

    # Column offsets for this program; BLOCK_SIZE may exceed n_cols, so
    # out-of-range lanes are masked on every memory access below.
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    # Advance by the row stride to reach the start of this program's row.
    row_ptrs = input_ptr + pid * input_row_stride + cols
    # Masked lanes load -inf so they contribute exp(-inf) = 0 to the sum.
    vals = tl.load(row_ptrs, mask=in_bounds, other=-float('inf'))

    # Subtract the row maximum for numerical stability.
    shifted = vals - tl.max(vals, axis=0)
    # Note: Triton's exp is fast but approximate (think __expf in CUDA).
    exps = tl.exp(shifted)
    probs = exps / tl.sum(exps, axis=0)

    # Write the normalized row back to DRAM.
    out_ptrs = output_ptr + pid * output_row_stride + cols
    tl.store(out_ptrs, probs, mask=in_bounds)
840

941
@triton.jit
1042
def _softmax_kernel_online(
@@ -16,7 +48,6 @@ def _softmax_kernel_online(
1648
n_cols,
1749
BLOCK_SIZE: tl.constexpr,
1850
):
19-
2051
row_start = tl.program_id(0)
2152
row_idx = row_start
2253

@@ -54,7 +85,6 @@ def _softmax_kernel_online(
5485
output_ptrs = output_row_start_ptr + col_offsets
5586
tl.store(output_ptrs, softmax_output, mask=mask)
5687

57-
5888
def softmax(x):
5989
"""
6090
Computes the row-wise softmax of a 2D input tensor.
@@ -73,17 +103,20 @@ def softmax(x):
73103
n_rows, n_cols = x.shape
74104

75105
MAX_FUSED_SIZE = 65536 // x.element_size()
106+
print("MAX_FUSED_SIZE: ", MAX_FUSED_SIZE)
76107
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
108+
print("BLOCK_SIZE: ", BLOCK_SIZE)
77109
y = torch.empty_like(x)
78110

79-
waves_per_eu = 2
111+
waves_per_eu = 4 # 2
80112
num_warps = 8
81113
num_stages = 2
82114

83115
num_programs = n_rows
84116

85117
grid = lambda meta: (num_programs,) # noqa: E731
86118
_softmax_kernel_online[grid](
119+
# _softmax_kernel[grid](
87120
y,
88121
x,
89122
x.stride(0),

op_tests/triton_tests/test_softmax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,5 @@ def test_softmax(M, N, dtype):
3636
atol, rtol = 1e-5, 1e-5
3737

3838
triton.testing.assert_close(y_triton, y_torch, atol=atol, rtol=rtol)
if __name__ == "__main__":
    # Run one large smoke case only when this file is executed directly.
    # Calling test_softmax at import time would execute an expensive GPU
    # test during pytest collection (and run it twice under pytest).
    test_softmax(32768, 8192, "bf16")

0 commit comments

Comments
 (0)