55
66_LOGGER = AiterTritonLogger ()
77
@triton.jit
def _softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Numerically-stable row-wise softmax; one program instance per row.

    Rows are independent, so the grid parallelizes across rows. BLOCK_SIZE
    is chosen by the launcher as the next power of two >= n_cols, so an
    entire row fits in a single block.
    """
    row = tl.program_id(0)
    # Guard against a grid larger than the number of rows.
    if row >= n_rows:
        return

    # The stride is the element count between consecutive rows.
    col_idx = tl.arange(0, BLOCK_SIZE)
    in_bounds = col_idx < n_cols
    # Masked load: lanes past n_cols read -inf so they contribute
    # exp(-inf) == 0 to the reduction below.
    vals = tl.load(
        input_ptr + row * input_row_stride + col_idx,
        mask=in_bounds,
        other=-float("inf"),
    )

    # Subtract the row max before exponentiating for numerical stability.
    shifted = vals - tl.max(vals, axis=0)
    # Note: Triton's exp is fast but approximate (think __expf in CUDA).
    exps = tl.exp(shifted)
    probs = exps / tl.sum(exps, axis=0)

    # Write the normalized row back to DRAM, masking the tail lanes.
    tl.store(
        output_ptr + row * output_row_stride + col_idx,
        probs,
        mask=in_bounds,
    )
840
941@triton .jit
1042def _softmax_kernel_online (
@@ -16,7 +48,6 @@ def _softmax_kernel_online(
1648 n_cols ,
1749 BLOCK_SIZE : tl .constexpr ,
1850):
19-
2051 row_start = tl .program_id (0 )
2152 row_idx = row_start
2253
@@ -54,7 +85,6 @@ def _softmax_kernel_online(
5485 output_ptrs = output_row_start_ptr + col_offsets
5586 tl .store (output_ptrs , softmax_output , mask = mask )
5687
57-
5888def softmax (x ):
5989 """
6090 Computes the row-wise softmax of a 2D input tensor.
@@ -73,17 +103,20 @@ def softmax(x):
73103 n_rows , n_cols = x .shape
74104
75105 MAX_FUSED_SIZE = 65536 // x .element_size ()
106+ print ("MAX_FUSED_SIZE: " , MAX_FUSED_SIZE )
76107 BLOCK_SIZE = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
108+ print ("BLOCK_SIZE: " , BLOCK_SIZE )
77109 y = torch .empty_like (x )
78110
79- waves_per_eu = 2
111+ waves_per_eu = 4 # 2
80112 num_warps = 8
81113 num_stages = 2
82114
83115 num_programs = n_rows
84116
85117 grid = lambda meta : (num_programs ,) # noqa: E731
86118 _softmax_kernel_online [grid ](
119+ # _softmax_kernel[grid](
87120 y ,
88121 x ,
89122 x .stride (0 ),
0 commit comments