Skip to content

Conversation

@avik-pal
Copy link
Collaborator

@avik-pal avik-pal commented Nov 30, 2025

fixes #1675

  • loop info
    • affine index propagation
    • remove the loop offset map
  • loop to batch
    • general affine indexing support
    • step != 1
    • negative scaling
    • iota indexing with scaled index
    • constant iteration argument
  • while is copy simplify
    • support affine indexing
    • step != 1
    • negative scaling
  • Miscellaneous issues
    • fix regression in nested_loop_conv test
    • regression in other testing
    • infinite compile time
  • Multiple indices depend on ind var

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EnzymeJAX Benchmarks

Benchmark suite Current: 858c436 Previous: 8df9e47 Ratio
scatter_sum / JaX / cpu / Primal 0.00000435670900005789 s 0.000004262526800084742 s 1.02
scatter_sum / JaXPipe / cpu / Primal 0.000004330562999894028 s 0.000004347323699948902 s 1.00
scatter_sum / JaX / tpu / Primal 0.0001417065681001 s 0.0001549653354999 s 0.91
scatter_sum / JaXPipe / tpu / Primal 0.0001430293372002 s 0.0001542208233999 s 0.93

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch 5 times, most recently from b426db5 to 6be8966 Compare December 2, 2025 22:35
@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch 2 times, most recently from 820ebf7 to 29a0b1a Compare December 4, 2025 15:56
@avik-pal avik-pal marked this pull request as ready for review December 4, 2025 16:00
@avik-pal avik-pal requested a review from wsmoses December 4, 2025 16:00
@avik-pal
Copy link
Collaborator Author

avik-pal commented Dec 4, 2025

using Reactant

Reactant.Compiler.WHILE_UNROLL_THRESHOLD[] = 0

function negative_mul_indexing(x)
    y = similar(x)
    fill!(y, 0)
    @trace for i in 1:2:6
        @allowscalar y[2:6, 20 - 2i, 3:4] = cos.(x[1:5, 20 - 3i, 1:2])
    end
    return y
end

x = Reactant.to_rarray(rand(Float32, 6, 20, 5));

@code_hlo negative_mul_indexing(x)
module {
  func.func @main(%arg0: tensor<5x20x6xf32> {enzymexla.memory_effects = []}) -> tensor<5x20x6xf32> attributes {enzymexla.memory_effects = []} {
    %c = stablehlo.constant dense<[[2, 17, 1], [2, 13, 1], [2, 9, 1]]> : tensor<3x3xi32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<5x20x6xf32>
    %0 = stablehlo.slice %arg0 [0:2, 4:17:6, 0:5] : (tensor<5x20x6xf32>) -> tensor<2x3x5xf32>
    %1 = stablehlo.reverse %0, dims = [1] : tensor<2x3x5xf32>
    %2 = stablehlo.cosine %1 : tensor<2x3x5xf32>
    %3 = "stablehlo.scatter"(%cst, %c, %2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2], inserted_window_dims = [1], scatter_dims_to_operand_dims = [0, 1, 2], index_vector_dim = 1>, unique_indices = true}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      stablehlo.return %arg2 : tensor<f32>
    }) : (tensor<5x20x6xf32>, tensor<3x3xi32>, tensor<2x3x5xf32>) -> tensor<5x20x6xf32>
    return %3 : tensor<5x20x6xf32>
  }
}

@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch 13 times, most recently from 3ffed93 to a47b63c Compare December 6, 2025 17:05
@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch 3 times, most recently from 70a7f58 to ae4ea37 Compare December 6, 2025 22:21
@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch from ae4ea37 to b1e11d2 Compare December 6, 2025 22:22
@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch from b1e11d2 to 1a2288d Compare December 7, 2025 00:55
@wsmoses
Copy link
Member

wsmoses commented Dec 7, 2025

the jax integration tests complain

@avik-pal
Copy link
Collaborator Author

avik-pal commented Dec 7, 2025

looking into them

@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch 2 times, most recently from 668ede7 to 10f47ab Compare December 9, 2025 02:17
feat: generalize auto-batching to support affine indexing

test: affine loop

feat: support constant loop arguments

feat: iota indexing with affine index

refactor: cleanup while_is_copy_simplify

feat: generalize while is copy

fix: use correct APIs

chore: run fmt

fix: add check

feat: step != 1

feat: generalize while is copy simplification pattern

fix: add additional checks

feat: support negative scale

fix: infinite compile in dus_pad

fix: test

test: neg scale

fix: assertion

fix: transpose_reshape
@avik-pal avik-pal force-pushed the ap/generalize_to_affine_indexing branch from 10f47ab to d598cd1 Compare December 9, 2025 03:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

While Loop raising with multiple induction var indexing

3 participants