 # ===-----------------------------------------------------------------------===#


-@tl.constexpr_function
+@gluon.constexpr_function
 def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     assert len(shape) == 2, "expected a 2D tensor"
     assert num_warps in [4, 8], "expected 4 or 8 warps"
@@ -60,15 +60,15 @@ def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     )


-@tl.constexpr_function
+@gluon.constexpr_function
 def get_mma_instr_shape(shape, element_ty):
     m = 128 if shape[0] >= 128 else 64
     n = 256 if shape[1] >= 256 else shape[1]
     k = 256 // element_ty.primitive_bitwidth
     return (m, n, k)

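# Note (not part of this commit): a worked instance of the arithmetic above,
# assuming a (128, 128) block of a 16-bit dtype such as fp16:
#     m = 128          # shape[0] == 128 >= 128
#     n = 128          # shape[1] == 128 < 256, so n falls back to shape[1]
#     k = 256 // 16    # == 16, giving instr_shape == (128, 128, 16)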
-@tl.constexpr_function
+@gluon.constexpr_function
 def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     packing_factor = 2 if fp4_padded else 1

@@ -100,7 +100,7 @@ def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     )


-@tl.constexpr_function
+@gluon.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
     return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
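# Note (not part of this commit): get_mma_reg_layout just chains the helpers
# above. With the default gl.float32 dtype (32-bit) and shape == (128, 128),
# it derives instr_shape == (128, 128, 256 // 32) == (128, 128, 8) and hands
# that to get_tmem_32x32b_reg_layout together with shape and num_warps.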
@@ -111,7 +111,7 @@ def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
 # ===-----------------------------------------------------------------------===#


-@tl.constexpr_function
+@gluon.constexpr_function
 def get_load_size_bytes(desc):
     size = desc.dtype.primitive_bitwidth // 8
     for dim in desc.block_type.shape:
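# Note (not part of this commit): the hunk ends mid-function. A plausible
# completion, stated as an assumption rather than the actual body, is that
# the loop scales `size` by each block dimension:
#     size *= dim      # e.g. a (128, 64) fp16 descriptor -> 2 * 128 * 64 bytes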
@@ -385,7 +385,7 @@ def __init__(self, channel, instr_shape, shape):
     def release(self):
         self.channel.release()

-    @tl.constexpr_function
+    @gluon.constexpr_function
     def get_reg_layout(self, num_warps):
         return get_tmem_32x32b_reg_layout(self.instr_shape, self.shape, num_warps)

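# Note (not part of this commit): every hunk in this change is the same
# mechanical rename, @tl.constexpr_function -> @gluon.constexpr_function.
# Assuming the new decorator keeps the old semantics, these helpers remain
# plain Python evaluated at compile time, e.g. (import path is an assumption):
#     from triton.experimental import gluon
#
#     @gluon.constexpr_function
#     def get_k(element_ty):
#         return 256 // element_ty.primitive_bitwidth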