Is your feature request related to a problem? Please describe.
Grouped operations with grouped offsets (e.g., Group Layer Norm) are hard to express efficiently in Helion.
I was writing a Helion kernel where weights and biases are shared across groups of rows (e.g., 3 groups), with group membership determined by a group_offset array.
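For reference, here is a minimal pure-Python sketch of the semantics I have in mind. It is not Helion code. It assumes that group_offsets holds G + 1 row boundaries, so that group g owns rows [group_offsets[g], group_offsets[g + 1]); the function name and the eps default are made up for illustration.

```python
import math


def grouped_layer_norm_ref(x, weight, bias, group_offsets, eps=1e-5):
    """Reference semantics only (hypothetical helper, not a Helion API).

    x:             M rows of N floats
    weight, bias:  G rows of N floats, one row per group
    group_offsets: assumed to hold G + 1 boundaries; group g owns rows
                   [group_offsets[g], group_offsets[g + 1])
    """
    out = []
    for g in range(len(group_offsets) - 1):
        w, b = weight[g], bias[g]  # shared by every row in group g
        for m in range(group_offsets[g], group_offsets[g + 1]):
            row = x[m]
            mean = sum(row) / len(row)
            var = sum((v - mean) ** 2 for v in row) / len(row)
            rstd = 1.0 / math.sqrt(var + eps)
            out.append(
                [(v - mean) * rstd * wi + bi for v, wi, bi in zip(row, w, b)]
            )
    return out
```

The outer loop over groups and the inner loop over rows are what I would like to map onto hl.grid and hl.tile respectively.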
Describe the solution you'd like
An hl.grid outer loop with an hl.tile inner loop, where hl.grid iterates over the group offsets so that every tile inside uses the same shared weight and bias matrices. Ideally, the tiles inside one grid iteration would not all have to run on the same SM.
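To illustrate the scheduling I am asking for, the sketch below flattens the (group, tile) nest into independent work items on the host. This is plain Python and not a Helion API; flatten_group_tiles and block_m are hypothetical names, and group_offsets is again assumed to hold G + 1 row boundaries.

```python
def flatten_group_tiles(group_offsets, block_m):
    """Enumerate (group, row_start, row_end) work items.

    Every item only needs the weights and biases of its own group, so in
    principle each one could be scheduled on any SM. Hypothetical helper
    for illustration; tiles never cross a group boundary.
    """
    work = []
    for g in range(len(group_offsets) - 1):
        start, end = group_offsets[g], group_offsets[g + 1]
        for row_start in range(start, end, block_m):
            work.append((g, row_start, min(row_start + block_m, end)))
    return work


# With 3 groups of 5, 1, and 6 rows and block_m = 4 there are 5 work
# items rather than 3, so parallelism scales with rows, not with G.
items = flatten_group_tiles([0, 5, 6, 12], 4)
```

With realistic row counts the number of work items is far larger than G, which is the parallelism that hl.grid(G) alone did not seem to expose.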
Describe alternatives you've considered
I tried two approaches and neither worked well:
hl.grid(G) + hl.tile inside: With a small number of groups (G=3), I think this only launched on 3 SMs, massively underutilizing the GPU. Performance was very slow.
I also tried hl.tile(M) with indirect indexing (embedding-style weight[group_offsets[tile_m], :] to look up per-row weights/biases). I think Helion made it hard to assign group indices inside each tile, and the solution felt counter-intuitive given that hl.grid seems to be made for iterating over scalar integer indices (e.g., group offsets).
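For the indirect-indexing route, the lookup needs one group id per row. If the offsets are stored as G + 1 boundaries (an assumption on my side), they first have to be expanded, roughly like the plain-Python sketch below. rows_to_group_ids is a hypothetical helper name, not something Helion provides.

```python
from bisect import bisect_right


def rows_to_group_ids(group_offsets, num_rows):
    """Map each row index to the group that owns it.

    Assumes group_offsets is sorted, starts at 0, and group g owns rows
    [group_offsets[g], group_offsets[g + 1]). Empty groups are skipped
    because bisect_right lands after repeated boundaries.
    """
    return [bisect_right(group_offsets, m) - 1 for m in range(num_rows)]


group_ids = rows_to_group_ids([0, 2, 3], 3)
```

The resulting per-row ids are what the embedding-style gather indexes with, which is the extra step that made this approach feel indirect compared to looping over groups.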
Additional context