
Conversation

@avik-pal (Collaborator) commented Dec 7, 2025

  • add a pass to simplify dot(bcast, bcast) (see the StableHLO sketch below)
  • simplify the concat → reshape → reduce pattern
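
A minimal sketch of the dot(bcast, bcast) case in StableHLO (shapes and dimension numbers are illustrative, not taken from the PR's tests): when both dot_general operands are broadcast along the same batch dimension, the contraction can be performed once on the small operands and the broadcast moved after the dot_general.

// Before: both operands are broadcast along batch dim 0, then contracted per batch.
func.func @before(%a: tensor<4x8xf32>, %b: tensor<8x16xf32>) -> tensor<2x4x16xf32> {
  %0 = stablehlo.broadcast_in_dim %a, dims = [1, 2] : (tensor<4x8xf32>) -> tensor<2x4x8xf32>
  %1 = stablehlo.broadcast_in_dim %b, dims = [1, 2] : (tensor<8x16xf32>) -> tensor<2x8x16xf32>
  %2 = stablehlo.dot_general %0, %1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x4x8xf32>, tensor<2x8x16xf32>) -> tensor<2x4x16xf32>
  return %2 : tensor<2x4x16xf32>
}

// After: contract once on the unbroadcast operands, then broadcast the result.
func.func @after(%a: tensor<4x8xf32>, %b: tensor<8x16xf32>) -> tensor<2x4x16xf32> {
  %0 = stablehlo.dot_general %a, %b, contracting_dims = [1] x [0] : (tensor<4x8xf32>, tensor<8x16xf32>) -> tensor<4x16xf32>
  %1 = stablehlo.broadcast_in_dim %0, dims = [1, 2] : (tensor<4x16xf32>) -> tensor<2x4x16xf32>
  return %1 : tensor<2x4x16xf32>
}

Every batch slice of @before computes the same 4x16 product, so @after does the multiply-accumulate work once instead of twice.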

@avik-pal changed the title from "fix: convert mul -> reduce to dot_general" to "fix: convert reduce(mul) to dot_general" on Dec 7, 2025
- // CHECK-NEXT: %0 = stablehlo.multiply %arg1, %arg0 : tensor<32x32xf64>
- // CHECK-NEXT: %1 = stablehlo.reduce(%0 init: %cst) applies stablehlo.add across dimensions = [1, 0] : (tensor<32x32xf64>, tensor<f64>) -> tensor<f64>
- // CHECK-NEXT: return %1 : tensor<f64>
+ // CHECK-NEXT: %0 = stablehlo.dot_general %arg1, %arg0, contracting_dims = [1, 0] x [1, 0] : (tensor<32x32xf64>, tensor<32x32xf64>) -> tensor<f64>
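
For readability, the two forms under discussion reconstructed as self-contained functions (the zero init constant and the function signature are assumed from the CHECK lines, not copied verbatim from the test file):

// Before the rewrite: elementwise multiply, then a full reduction to a scalar.
func.func @before(%arg0: tensor<32x32xf64>, %arg1: tensor<32x32xf64>) -> tensor<f64> {
  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
  %0 = stablehlo.multiply %arg1, %arg0 : tensor<32x32xf64>
  %1 = stablehlo.reduce(%0 init: %cst) applies stablehlo.add across dimensions = [1, 0] : (tensor<32x32xf64>, tensor<f64>) -> tensor<f64>
  return %1 : tensor<f64>
}

// After the rewrite: both dimensions become contracting dims, i.e. a dot product.
func.func @after(%arg0: tensor<32x32xf64>, %arg1: tensor<32x32xf64>) -> tensor<f64> {
  %0 = stablehlo.dot_general %arg1, %arg0, contracting_dims = [1, 0] x [1, 0] : (tensor<32x32xf64>, tensor<32x32xf64>) -> tensor<f64>
  return %0 : tensor<f64>
}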
Member

This case here is unclear to me whether it would be an improvement.

Member

On the one hand, it is certainly fewer instructions/simpler, and it is also the definition of a dot product. On the other hand, if it calls gemm or related, that's clearly worse.

something to ponder/benchmark

Collaborator Author

XLA is smart about this and doesn't call gemm unless it really has to:

julia> fn(x, y) = Reactant.Ops.dot_general(x, y; contracting_dimensions=([1, 2], [1, 2]))
fn (generic function with 2 methods)

julia> @code_hlo fn(Reactant.to_rarray(rand(Float32, 4, 5)), Reactant.to_rarray(rand(Float32, 4, 5)))
module @reactant_fn attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<5x4xf32> {enzymexla.memory_effects = []}, %arg1: tensor<5x4xf32> {enzymexla.memory_effects = []}) -> tensor<f32> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1, 0] x [1, 0], precision = [DEFAULT, DEFAULT] : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}

julia> @code_xla fn(Reactant.to_rarray(rand(Float32, 4, 5)), Reactant.to_rarray(rand(Float32, 4, 5)))
HloModule reactant_fn, is_scheduled=true, entry_computation_layout={(f32[5,4]{1,0}, f32[5,4]{1,0})->f32[]}, frontend_attributes={fingerprint_before_lhs="66ce375ff85a3f8ea88cdf73b6088680"}

FileNames
1 "/mnt2/avik-pal/reactant/Reactant.jl/src/Ops.jl"

FunctionNames
1 "dot_general"

FileLocations
1 {file_name_id=1 function_name_id=1 line=944 end_line=944 column=0 end_column=0}

StackFrames
1 {file_location_id=1 parent_frame_id=1}


%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[] parameter(1)
  %scalar_lhs = f32[] parameter(0)
  ROOT %add.1 = f32[] add(%scalar_lhs, %scalar_rhs)
}

%fused_reduce (param_0.2: f32[5,4], param_1.3: f32[5,4]) -> f32[] {
  %param_0.2 = f32[5,4]{1,0} parameter(0)
  %param_1.3 = f32[5,4]{1,0} parameter(1)
  %multiply.1.1 = f32[5,4]{1,0} multiply(%param_0.2, %param_1.3), metadata={op_name="dot_general" stack_frame_id=1}
  %bitcast.2 = f32[20]{0} bitcast(%multiply.1.1), metadata={op_name="dot_general" stack_frame_id=1}
  %constant_2 = f32[] constant(0)
  ROOT %reduce.4 = f32[] reduce(%bitcast.2, %constant_2), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="dot_general" stack_frame_id=1}
}

ENTRY %main.1 (arg1__path___args__1__.1: f32[5,4], arg2__path___args__2__.1: f32[5,4]) -> f32[] {
  %arg2__path___args__2__.1 = f32[5,4]{1,0} parameter(1), metadata={op_name="arg2 (path=(:args, 2))"}
  %arg1__path___args__1__.1 = f32[5,4]{1,0} parameter(0), metadata={op_name="arg1 (path=(:args, 1))"}
  ROOT %loop_reduce_fusion = f32[] fusion(%arg1__path___args__1__.1, %arg2__path___args__2__.1), kind=kInput, calls=%fused_reduce, metadata={op_name="dot_general" stack_frame_id=1}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID","native_emitter_backend_config":{}}
}

@avik-pal marked this pull request as ready for review on December 8, 2025 05:37
@avik-pal requested a review from wsmoses on December 8, 2025 05:37
@github-actions bot (Contributor) left a comment

EnzymeJAX Benchmarks

Benchmark suite                      | Current: eede2db          | Previous: 5fa1fb9         | Ratio
scatter_sum / JaX / cpu / Primal     | 0.000004240672000742052 s | 0.000004337968700019701 s | 0.98
scatter_sum / JaXPipe / cpu / Primal | 0.000004327424000075553 s | 0.000004346270799942431 s | 1.00
scatter_sum / JaX / tpu / Primal     | 0.0001453522068011 s      | 0.0001424101054999 s      | 1.02
scatter_sum / JaXPipe / tpu / Primal | 0.0001543985108 s         | 0.0001357899519 s         | 1.14

This comment was automatically generated by workflow using github-action-benchmark.

Copilot AI left a comment

Pull request overview

This PR improves HLO optimization by converting reduce(multiply(...)) patterns to more efficient dot_general operations. The key changes rename ReduceMulBroadcastToDotGeneral to ReduceMulToDotGeneral and generalize it to handle more cases beyond just broadcast patterns. Additionally, two new patterns (DotGeneralBroadcastInDim and DotGeneralBroadcastInDimSortDims) are introduced to simplify dot operations with broadcasted inputs by moving broadcasts after the dot_general when beneficial.

Key changes:

  • Generalized the reduce(multiply) to dot_general conversion to work on any multiply operation, not just broadcasts (see the sketch after this list)
  • Added patterns to optimize dot_general operations with broadcasted operands
  • Updated auto-batching infrastructure to support intermediate reshape operations
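
A sketch of the generalized conversion with assumed shapes (not one of the PR's test cases): when the reduce covers only some dimensions of the multiply, the reduced dimensions become contracting dims of the dot_general and the surviving dimensions become batching dims.

// reduce(multiply) over dim 1 only, keeping dim 0 ...
func.func @partial(%a: tensor<4x8xf32>, %b: tensor<4x8xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
  %0 = stablehlo.multiply %a, %b : tensor<4x8xf32>
  %1 = stablehlo.reduce(%0 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// ... maps to a batched dot_general: dim 0 batches, dim 1 contracts.
func.func @partial_dot(%a: tensor<4x8xf32>, %b: tensor<4x8xf32>) -> tensor<4xf32> {
  %0 = stablehlo.dot_general %a, %b, batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

Row i of the result is dot(a[i, :], b[i, :]) in both versions; the dot_general form states the contraction explicitly instead of materializing the elementwise product.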

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

File | Description
test/lit_tests/reducetranspose.mlir | Updated test expectations for the reduce-to-dot_general optimization (contains a bug in the return statement)
test/lit_tests/reducereshape.mlir | Updated test expectations (contains an incorrect variable reference)
test/lit_tests/reduce_mul_broadcast_to_dot_general.mlir | Expanded test cases for the generalized pattern, with new examples
test/lit_tests/diffrules/stablehlo/while4.mlir | Updated variable numbering due to constant elimination
test/lit_tests/diagonal_dotgeneral.mlir | Updated expectations to use dot_general instead of multiply + reduce
test/lit_tests/concatreduce3.mlir | Added the auto-batching pass and updated expectations for the dot_general optimization
test/lit_tests/autobatching/reduce_loop.mlir | Updated expectations for batched dot_general with transpose
test/lit_tests/autobatching/concatreshapedotgeneral.mlir | Updated broadcast dimensions for the optimized dot_general
src/enzyme_ad/jax/primitives.py | Renamed the pattern from reduce_mul_broadcast_to_dot_general to reduce_mul_to_dot_general and added two new patterns
src/enzyme_ad/jax/TransformOps/TransformOps.td | Renamed the pattern definition and added two new pattern definitions
src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | Generalized the ReduceMulToDotGeneral implementation and added the DotGeneralBroadcastInDim patterns (minor commented-out code still needs cleanup)
src/enzyme_ad/jax/Passes/AutoBatching.h | Updated function signatures to support intermediate insertions
src/enzyme_ad/jax/Passes/AutoBatching.cpp | Enhanced auto-batching to handle intermediate reshapes and broadcasts via a TypeSwitch pattern


@avik-pal (Collaborator Author) commented Dec 9, 2025

tests should now pass

@avik-pal force-pushed the ap/fix_bcast_mul_to_dot_general branch from 16fff85 to eede2db on December 9, 2025 02:24