Skip to content

Conversation

Copy link

Copilot AI commented Nov 25, 2025

Implements optimization to simplify consecutive stablehlo.reverse operations. When reversing the same dimension twice, the operations cancel out. This addresses chains of reverses appearing in autodiff-generated code (ref #1646).

Changes

  • New ReverseReverse pattern in EnzymeHLOOpt.cpp: Computes symmetric difference of dimensions between consecutive reverses
    • Dimensions in both → cancel out (identity)
    • Dimensions in only one → remain
  • Pattern registration in TransformOps.td as reverse_reverse
  • Test coverage in reversereverse.mlir

Example

// Before
%0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32>
%1 = stablehlo.reverse %0, dims = [1, 2] : tensor<8x4x3xf32>

// After (dims [1] cancel, [0] and [2] remain)
%0 = stablehlo.reverse %arg0, dims = [0, 2] : tensor<8x4x3xf32>

When all dimensions cancel:

// Before
%0 = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32>
%1 = stablehlo.reverse %0, dims = [0] : tensor<8x4xf32>

// After: no reverse ops, returns %arg0 directly

Warning

Firewall rules blocked me from connecting to one or more addresses (expand for details)

I tried to connect to the following addresses, but was blocked by firewall rules:

  • releases.bazel.build
    • Triggering command: /usr/local/lib/node_modules/@bazel/bazelisk/bazelisk-linux_amd64 /usr/local/lib/node_modules/@bazel/bazelisk/bazelisk-linux_amd64 build //:enzymexlamlir-opt (dns block)
    • Triggering command: /usr/local/lib/node_modules/@bazel/bazelisk/bazelisk-linux_amd64 /usr/local/lib/node_modules/@bazel/bazelisk/bazelisk-linux_amd64 version (dns block)

If you need me to access, download, or install something from one of these locations, you can either:

Original prompt

This section details on the original issue you should resolve

<issue_title>Reverse(reverse) -> noop</issue_title>
<issue_description>see #1646</issue_description>

<agent_instructions>Write an optimization in enzymehloopt.cpp like all the others in the file and corresponding utility files, that implements the optimization reverse(reverse(x)) -> x, where legal</agent_instructions>

Comments on the Issue (you are @copilot in this section)


💬 We'd love your input! Share your thoughts on Copilot coding agent in our 2 minute survey.

Copilot AI changed the title [WIP] Add optimization for reverse(reverse(x)) to noop Add reverse(reverse(x)) -> x optimization pattern Nov 25, 2025
Copilot AI requested a review from wsmoses November 25, 2025 05:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Reverse(reverse) -> noop

2 participants