We are currently optimizing the LayerNorm kernel for AMD MI350X (gfx950). Despite running the Helion autotuner multiple times, the Helion LayerNorm kernel remains slower than Inductor's on this hardware for some shapes. The gap is largest for the (1, 2304, 36864) shape, where Helion is more than 2x slower than Inductor (roughly 2.9x: Helion 0.72 vs. Inductor 0.25).
Re-autotuning the kernel did not yield any improvement.
# NOTE(review): this is the body of Helion's generated host wrapper for the
# LayerNorm forward pass. Its `def` header (which presumably binds x, weight,
# bias, normalized_shape, eps, m and n) is not part of this paste, and the
# original indentation was lost.
# src[layer_norm.py:2377]: if bias is not None:
# src[layer_norm.py:2378]: assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
if bias is not None:
# src[layer_norm.py:2378]: assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {n}'
# src[layer_norm.py:2379]: assert len(normalized_shape) == 1, (
# src[layer_norm.py:2380]: "Helion layer norm only supports 1D layer norm currently"
# src[layer_norm.py:2381]: )
assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
# src[layer_norm.py:2382]: assert normalized_shape[0] == n, (
# src[layer_norm.py:2383]: f"normalized shape mismatch {normalized_shape[0]} != {n}"
# src[layer_norm.py:2384]: )
assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
# Outputs: normalized activations in x's dtype, plus per-row fp32 mean and
# rstd (the latter two are returned alongside `out`).
# src[layer_norm.py:2385]: out = torch.empty([m, n], dtype=x.dtype, device=x.device)
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
# src[layer_norm.py:2386]: mean = torch.empty([m], dtype=torch.float32, device=x.device)
mean = torch.empty([m], dtype=torch.float32, device=x.device)
# src[layer_norm.py:2387]: rstd = torch.empty([m], dtype=torch.float32, device=x.device)
rstd = torch.empty([m], dtype=torch.float32, device=x.device)
# Reduction width passed to the kernel. 65536 is the next power of two above
# n = 36864 for the reported shape, so each program presumably reduces a whole
# (padded) row in a single block.
# NOTE(review): a single 65536-wide reduction per program is a likely perf
# hotspot on this shape. 65536 - 36864 = 28672 lanes (~44%) are masked
# padding, and with num_warps=8 on a 64-wide wavefront (512 lanes) each lane
# holds ~128 fp32 values, which risks register spilling. Inductor instead
# loops over the row in R0_BLOCK chunks. Consider a Helion config with a
# looped reduction (e.g. `reduction_loops`). TODO confirm against the
# kernel's spill/ISA stats.
# src[layer_norm.py:2389]: for tile_m in hl.tile(m):
_RDIM_SIZE_1 = 65536
# src[layer_norm.py:2389]: for tile_m in hl.tile(m):
# src[layer_norm.py:2390]: acc = x[tile_m, :].to(torch.float32)
# src[layer_norm.py:2391]: # Compute mean
# src[layer_norm.py:2389-2409]: ...
# Launch grid is (m,), i.e. one program per row. The kernel presumably fills
# out, mean and rstd in place. The compile options at the end are presumably
# the autotuner's chosen config.
_launcher(_helion_layer_norm_fwd, (m,), x, weight, bias, out, mean, rstd, mean.size(0), bias.stride(0), mean.stride(0), out.stride(0), out.stride(1), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), eps, _RDIM_SIZE_1, num_warps=8, num_stages=2, matrix_instr_nonkdim=32, waves_per_eu=4)
# src[layer_norm.py:2410]: return out, mean, rstd
return (out, mean, rstd)
# NOTE(review): truncated fragment. Only the header and first import of a
# `call()` repro harness survive here (presumably the one for the forward
# wrapper above); the rest of its body was lost in this paste.
def call():
from torch._dynamo.testing import rand_strided
# NOTE(review): body of Helion's generated host wrapper for `layer_norm_bwd`
# (name taken from the src[layer_norm.py:2414] mapping in the repro below).
# Its `def` header and the preceding statements are not part of this paste,
# including whatever binds m_block, n and grad_x. The original indentation
# was also lost.
# Number of row blocks of size m_block (ceiling division over the rows of x).
# src[layer_norm.py:2445]: num_blocks = (x.size(0) + m_block - 1) // m_block
num_blocks = (x.size(0) + m_block - 1) // m_block
# One fp32 partial-sum row per row block. The kernel presumably fills these,
# and they are reduced with .sum(0) on the host after the launch.
# src[layer_norm.py:2446]: grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
# src[layer_norm.py:2447]: grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
# The launch grid below is _NUM_SM * 2 programs rather than one per row
# block. _NUM_SM is also passed to the kernel, presumably so each program can
# stride over the num_blocks row blocks.
# src[layer_norm.py:2449]: for mb_cta in hl.tile(x.size(0), block_size=m_block):
_NUM_SM = helion.runtime.get_num_sm(grad_out.device)
# Same 65536-wide reduction (next power of two above n = 36864 for the
# reported shape) as the forward wrapper.
# NOTE(review): likely register-pressure hotspot for this shape; check
# whether a looped-reduction config helps. TODO confirm.
_RDIM_SIZE_1 = 65536
# src[layer_norm.py:2449]: for mb_cta in hl.tile(x.size(0), block_size=m_block):
# src[layer_norm.py:2450]: grad_w_acc = weight.new_zeros(n, dtype=torch.float32)
# src[layer_norm.py:2451]: if compute_bias_grad:
# src[layer_norm.py:2449-2474]: ...
_launcher(_helion_layer_norm_bwd, (_NUM_SM * 2,), x, weight, grad_out, mean, rstd, grad_x, grad_weight_blocks, grad_bias_blocks, grad_bias_blocks.size(0), grad_x.size(0), rstd.size(0), x.size(0), grad_bias_blocks.stride(0), grad_bias_blocks.stride(1), grad_out.stride(0), grad_out.stride(1), grad_weight_blocks.stride(0), grad_weight_blocks.stride(1), grad_x.stride(0), grad_x.stride(1), mean.stride(0), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), _NUM_SM, _RDIM_SIZE_1, num_warps=8, num_stages=4, matrix_instr_nonkdim=32, waves_per_eu=1)
# Final cross-block reduction of the per-block partial sums, cast to
# weight's dtype.
# src[layer_norm.py:2476]: grad_weight = grad_weight_blocks.sum(0).to(weight.dtype)
grad_weight = grad_weight_blocks.sum(0).to(weight.dtype)
# Helion specialized `if compute_bias_grad:` to the constant True for this
# compile (compute_bias_grad=True in the repro). The trailing
# `return (grad_x, grad_weight, None)` is therefore unreachable here.
# src[layer_norm.py:2477]: if compute_bias_grad:
# src[layer_norm.py:2478]: grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
# src[layer_norm.py:2479]: return grad_x, grad_weight, grad_bias
if True:
# src[layer_norm.py:2478]: grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
# src[layer_norm.py:2479]: return grad_x, grad_weight, grad_bias
return (grad_x, grad_weight, grad_bias)
# src[layer_norm.py:2480]: return grad_x, grad_weight, None
return (grad_x, grad_weight, None)
def call():
    """Repro harness: run ``layer_norm_bwd`` once on random (2304, 36864) inputs.

    All tensors live on ``cuda:0``. Activations, gradient and weight are
    bf16; the per-row statistics are fp32.
    """
    from torch._dynamo.testing import rand_strided

    num_rows, hidden = 2304, 36864
    device = 'cuda:0'
    row_major = (hidden, 1)

    # Keep the allocation order of the generated repro so the RNG stream, and
    # hence the random inputs, is unchanged.
    grad_out = rand_strided(size=(num_rows, hidden), stride=row_major, dtype=torch.bfloat16, device=device)
    x = rand_strided(size=(num_rows, hidden), stride=row_major, dtype=torch.bfloat16, device=device)
    # Random stand-ins for the per-row mean / rstd produced by the forward pass.
    mean = rand_strided(size=(num_rows,), stride=(1,), dtype=torch.float32, device=device)
    rstd = rand_strided(size=(num_rows,), stride=(1,), dtype=torch.float32, device=device)
    weight = rand_strided(size=(hidden,), stride=(1,), dtype=torch.bfloat16, device=device)

    layer_norm_bwd(grad_out, x, mean, rstd, weight, True)  # compute_bias_grad


if __name__ == '__main__':
    call()
# AOT ID: ['4_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# Module-level short names for torch op namespaces and Dynamo/Inductor runtime
# helpers. The generated code below refers to these names directly (e.g.
# assert_size_stride, empty_strided_cuda, reinterpret_tensor, async_compile),
# so they must keep these exact spellings.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiles the Triton kernel source registered below via async_compile.triton(...).
async_compile = AsyncCompile()
# Not used by the visible code; presumably requires a build with c10d support.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /var/tmp/torchinductor_mengjiao/bd/cbd42ow37dkvyysrib5m7gbpaewfokss4zdpnauaftk4chzdltpt.py
# Topologically Sorted Source Nodes: [layer_norm], Original ATen: [aten.native_layer_norm]
# Source node to ATen node mapping:
# layer_norm => add, add_1, convert_element_type, convert_element_type_1, getitem_1, mul, mul_1, rsqrt, sub, var_mean
# Graph fragment:
# %primals_1 : Tensor "bf16[1, 2304, 36864][84934656, 36864, 1]cuda:0" = PlaceHolder[target=primals_1]
# %buf0 : Tensor "f32[1, 2304, 1][2304, 1, 2304]cuda:0" = PlaceHolder[target=buf0]
# %buf1 : Tensor "f32[1, 2304, 1][2304, 1, 2304]cuda:0" = PlaceHolder[target=buf1]
# %buf2 : Tensor "f32[1, 2304, 1][2304, 1, 2304]cuda:0" = PlaceHolder[target=buf2]
# %rsqrt : Tensor "f32[1, 2304, 1][2304, 1, 1]cuda:0" = PlaceHolder[target=rsqrt]
# %primals_2 : Tensor "bf16[36864][1]cuda:0" = PlaceHolder[target=primals_2]
# %primals_3 : Tensor "bf16[36864][1]cuda:0" = PlaceHolder[target=primals_3]
# %convert_element_type : Tensor "f32[1, 2304, 36864][84934656, 36864, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_1, torch.float32), kwargs = {})
# %var_mean : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type, [2]), kwargs = {correction: 0, keepdim: True})
# %getitem_1 : Tensor "f32[1, 2304, 1][2304, 1, 1]cuda:0"[num_users=2] = call_function[target=operator.getitem](args = (%var_mean, 1), kwargs = {})
# %add : Tensor "f32[1, 2304, 1][2304, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem, 1e-05), kwargs = {})
# %rsqrt : Tensor "f32[1, 2304, 1][2304, 1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
# %sub : Tensor "f32[1, 2304, 36864][84934656, 36864, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem_1), kwargs = {})
# %mul : Tensor "f32[1, 2304, 36864][84934656, 36864, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %rsqrt), kwargs = {})
# %mul_1 : Tensor "f32[1, 2304, 36864][84934656, 36864, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %primals_2), kwargs = {})
# %add_1 : Tensor "f32[1, 2304, 36864][84934656, 36864, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, %primals_3), kwargs = {})
# %convert_element_type_1 : Tensor "bf16[1, 2304, 36864][84934656, 36864, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.bfloat16), kwargs = {})
# return %buf0,%buf1,%buf2,%rsqrt,%convert_element_type_1
#
# NOTE(review): summary of the fused kernel below, for comparison with the
# Helion forward kernel. Each program handles XBLOCK rows
# (xoffset = program_id(0) * XBLOCK) and makes three passes over each
# 36864-wide row of in_ptr0 (x) in R0_BLOCK-sized chunks:
#   1. accumulate the row sum, divide by 36864 -> mean, stored to in_out_ptr0;
#   2. accumulate sum((x - mean)^2) / 36864, then rsqrt(var + 1e-05) -> rstd,
#      stored to in_out_ptr1;
#   3. reload x and store (x - mean) * rstd * in_ptr1 + in_ptr2
#      (weight and bias, per the graph above) to the bf16 out_ptr0.
# The first two passes load x with eviction_policy='evict_last' and the last
# pass with 'evict_first', presumably so the repeated row reads stay in cache.
# The accumulators are [XBLOCK, R0_BLOCK] tiles, so unlike a single
# 65536-wide block reduction, per-program register use is bounded by the
# chosen block sizes. XBLOCK / R0_BLOCK are presumably picked by Inductor's
# reduction heuristics (triton_heuristics.reduction).
triton_red_fused_native_layer_norm_0 = async_compile.triton('triton_red_fused_native_layer_norm_0', '''
import triton
import triton.language as tl
import triton.language.extra.tlx as tlx # noqa: F401
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 4096, 'r0_': 65536},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_out_ptr1': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='hip', index=0, multi_processor_count=256, cc='gfx950', major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=64), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16], ['tt.pointer_range', 32]], (1,): [['tt.divisibility', 16], ['tt.pointer_range', 32]], (2,): [['tt.divisibility', 16], ['tt.pointer_range', 32]], (3,): [['tt.divisibility', 16], ['tt.pointer_range', 32]], (4,): [['tt.divisibility', 16], ['tt.pointer_range', 32]], (5,): [['tt.divisibility', 16], ['tt.pointer_range', 32]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_0', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': False, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 3, 'num_reduction': 2, 'backend_hash': 'E3565433EC5266B9E41D740243879373C6297D57D928F6BAE46DD3A44A623A50', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 32, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'is_hip': True, 'is_fbcode': True}
)
@triton.jit
def triton_red_fused_native_layer_norm_0(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 2304
r0_numel = 36864
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + 36864*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
tmp4 = _tmp3 + tmp2
_tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
tmp3 = tl.sum(_tmp3, 1)[:, None]
tmp5 = tl.full([1, 1], 36864.0, tl.float32)
tmp6 = (tmp3 / tmp5)
tl.debug_barrier()
tl.store(in_out_ptr0 + (x0), tmp6, xmask)
_tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp7 = tl.load(in_ptr0 + (r0_1 + 36864*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp8 - tmp6
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
tmp13 = _tmp12 + tmp11
_tmp12 = tl.where(r0_mask & xmask, tmp13, _tmp12)
tmp12 = tl.sum(_tmp12, 1)[:, None]
tmp14 = tl.full([1, 1], 36864.0, tl.float32)
tmp15 = (tmp12 / tmp14)
tmp16 = tl.full([1, 1], 1e-05, tl.float32)
tmp17 = tmp15 + tmp16
tmp18 = tl.rsqrt(tmp17)
tl.debug_barrier()
tl.store(in_out_ptr1 + (x0), tmp18, xmask)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp19 = tl.load(in_ptr0 + (r0_1 + 36864*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp23 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp26 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp20 - tmp6
tmp22 = tmp21 * tmp18
tmp24 = tmp23.to(tl.float32)
tmp25 = tmp22 * tmp24
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp25 + tmp27
tmp29 = tmp28.to(tl.float32)
tl.store(out_ptr0 + (r0_1 + 36864*x0), tmp29, r0_mask & xmask)
''', device_str='cuda')
# Wait for the asynchronously compiled kernel(s) registered above. globals()
# is passed so the module-level kernel names can be resolved (presumably
# rebinding them to the finished kernels). The compiler handle is then
# deleted because nothing below uses it.
async_compile.wait(globals())
del async_compile
# Inductor forward entry point for the fused LayerNorm.
# args = [primals_1 (x, bf16 (1, 2304, 36864)), primals_2 (weight, bf16 (36864,)),
#         primals_3 (bias, bf16 (36864,))], per the asserts and the graph above.
# The caller's list is cleared, presumably so the inputs can be freed early.
def call(args):
primals_1, primals_2, primals_3 = args
args.clear()
assert_size_stride(primals_1, (1, 2304, 36864), (84934656, 36864, 1))
assert_size_stride(primals_2, (36864, ), (1, ))
assert_size_stride(primals_3, (36864, ), (1, ))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# buf1 receives the per-row mean and buf3 the per-row rstd (both fp32).
# buf4 receives the bf16 normalized output.
buf0 = empty_strided_cuda((1, 2304, 1), (2304, 1, 2304), torch.float32)
buf1 = buf0; del buf0 # reuse
buf2 = empty_strided_cuda((1, 2304, 1), (2304, 1, 2304), torch.float32)
buf3 = reinterpret_tensor(buf2, (1, 2304, 1), (2304, 1, 1), 0); del buf2 # reuse
buf4 = empty_strided_cuda((1, 2304, 36864), (84934656, 36864, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [layer_norm], Original ATen: [aten.native_layer_norm]
stream0 = get_raw_stream(0)
# Single fused launch: xnumel = 2304 rows, r0_numel = 36864 columns.
triton_red_fused_native_layer_norm_0.run(buf1, buf3, primals_1, primals_2, primals_3, buf4, 2304, 36864, stream=stream0)
# Bias is not needed after the kernel (it is not returned).
del primals_3
# Returns (output, x, weight, mean, rstd); the last four are presumably
# saved for the backward graph.
return (buf4, primals_1, primals_2, reinterpret_tensor(buf1, (1, 2304, 1), (2304, 1, 1), 0), buf3, )
def get_args():
    """Build random bf16 inputs matching the sizes/strides asserted in ``call``.

    Returns ``[x, weight, bias]`` with x of shape (1, 2304, 36864), contiguous,
    and weight/bias of shape (36864,), all on ``cuda:0``.
    """
    from torch._dynamo.testing import rand_strided

    hidden = 36864
    x = rand_strided((1, 2304, hidden), (2304 * hidden, hidden, 1), device='cuda:0', dtype=torch.bfloat16)
    # weight and bias are both 1-D over the normalized dimension; they are
    # generated in that order so the RNG stream matches the original.
    weight, bias = (
        rand_strided((hidden, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
        for _ in range(2)
    )
    return [x, weight, bias]
def benchmark_compiled_module(args, times=10, repeat=10):
    """Time ``call`` on ``args`` and report via Inductor's ``print_performance``.

    ``times`` and ``repeat`` are forwarded unchanged to ``print_performance``.
    """
    from torch._inductor.utils import print_performance

    def _run_once():
        # call() clears the list it receives, so give it a fresh copy each run.
        return call(list(args))

    return print_performance(_run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main

    args = get_args()

    def _bench(times, repeat):
        # Adapter with the (times, repeat) signature compiled_module_main expects.
        return benchmark_compiled_module(args, times=times, repeat=repeat)

    compiled_module_main('helion_layernorm,inductor_layernorm', _bench)
We are currently optimizing the LayerNorm kernel for AMD MI350X (gfx950). Despite running the Helion autotuner multiple times, the Helion LayerNorm kernel remains slower than Inductor's on this hardware for some shapes. The gap is largest for the (1, 2304, 36864) shape, where Helion is more than 2x slower than Inductor (roughly 2.9x: Helion 0.72 vs. Inductor 0.25).
Re-autotuning the kernel did not yield any improvement.
Performance Results
Backward pass
Forward pass
Versions
Triton Version: 3.5.0+fb
ROCm: 7.0.2.1 (using -m rocm70 mode)
Helion generated code:
Inductor generated code: