fix: convert reduce(mul) to dot_general
#1707
base: main
Conversation
reduce(mul) to dot_general
```diff
-// CHECK-NEXT: %0 = stablehlo.multiply %arg1, %arg0 : tensor<32x32xf64>
-// CHECK-NEXT: %1 = stablehlo.reduce(%0 init: %cst) applies stablehlo.add across dimensions = [1, 0] : (tensor<32x32xf64>, tensor<f64>) -> tensor<f64>
-// CHECK-NEXT: return %1 : tensor<f64>
+// CHECK-NEXT: %0 = stablehlo.dot_general %arg1, %arg0, contracting_dims = [1, 0] x [1, 0] : (tensor<32x32xf64>, tensor<32x32xf64>) -> tensor<f64>
```
It is unclear to me whether this case would be an improvement.
On the one hand, it is certainly fewer instructions/simpler, and it is also the definition of a dot product. On the other hand, if it calls gemm or related, that's clearly worse.
Something to ponder/benchmark.
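For concreteness, the two forms compute the same value; here is a minimal plain-Julia sketch (not part of this PR; shapes chosen to match the 32x32 test case above) of the equivalence that such a benchmark would compare:

```julia
using LinearAlgebra

# Shapes chosen to match the 32x32 test case above.
A = rand(Float64, 32, 32)
B = rand(Float64, 32, 32)

# What the original HLO computes: elementwise multiply, then reduce(add) over all dimensions.
reduce_mul = sum(A .* B)

# What the rewritten HLO computes: a full contraction of both operands,
# i.e. the dot product of the flattened tensors (stablehlo.dot_general with
# contracting_dims = [1, 0] x [1, 0]).
dot_form = dot(vec(A), vec(B))

reduce_mul ≈ dot_form  # true up to floating-point rounding
```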
XLA is smart about this and doesn't call gemm unless it really has to
```julia
julia> fn(x, y) = Reactant.Ops.dot_general(x, y; contracting_dimensions=([1, 2], [1, 2]))
fn (generic function with 2 methods)
julia> @code_hlo fn(Reactant.to_rarray(rand(Float32, 4, 5)), Reactant.to_rarray(rand(Float32, 4, 5)))
module @reactant_fn attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<5x4xf32> {enzymexla.memory_effects = []}, %arg1: tensor<5x4xf32> {enzymexla.memory_effects = []}) -> tensor<f32> attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1, 0] x [1, 0], precision = [DEFAULT, DEFAULT] : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<f32>
return %0 : tensor<f32>
}
}
julia> @code_xla fn(Reactant.to_rarray(rand(Float32, 4, 5)), Reactant.to_rarray(rand(Float32, 4, 5)))
HloModule reactant_fn, is_scheduled=true, entry_computation_layout={(f32[5,4]{1,0}, f32[5,4]{1,0})->f32[]}, frontend_attributes={fingerprint_before_lhs="66ce375ff85a3f8ea88cdf73b6088680"}
FileNames
1 "/mnt2/avik-pal/reactant/Reactant.jl/src/Ops.jl"
FunctionNames
1 "dot_general"
FileLocations
1 {file_name_id=1 function_name_id=1 line=944 end_line=944 column=0 end_column=0}
StackFrames
1 {file_location_id=1 parent_frame_id=1}
%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
%scalar_rhs = f32[] parameter(1)
%scalar_lhs = f32[] parameter(0)
ROOT %add.1 = f32[] add(%scalar_lhs, %scalar_rhs)
}
%fused_reduce (param_0.2: f32[5,4], param_1.3: f32[5,4]) -> f32[] {
%param_0.2 = f32[5,4]{1,0} parameter(0)
%param_1.3 = f32[5,4]{1,0} parameter(1)
%multiply.1.1 = f32[5,4]{1,0} multiply(%param_0.2, %param_1.3), metadata={op_name="dot_general" stack_frame_id=1}
%bitcast.2 = f32[20]{0} bitcast(%multiply.1.1), metadata={op_name="dot_general" stack_frame_id=1}
%constant_2 = f32[] constant(0)
ROOT %reduce.4 = f32[] reduce(%bitcast.2, %constant_2), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="dot_general" stack_frame_id=1}
}
ENTRY %main.1 (arg1__path___args__1__.1: f32[5,4], arg2__path___args__2__.1: f32[5,4]) -> f32[] {
%arg2__path___args__2__.1 = f32[5,4]{1,0} parameter(1), metadata={op_name="arg2 (path=(:args, 2))"}
%arg1__path___args__1__.1 = f32[5,4]{1,0} parameter(0), metadata={op_name="arg1 (path=(:args, 1))"}
ROOT %loop_reduce_fusion = f32[] fusion(%arg1__path___args__1__.1, %arg2__path___args__2__.1), kind=kInput, calls=%fused_reduce, metadata={op_name="dot_general" stack_frame_id=1}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID","native_emitter_backend_config":{}}
}
```
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EnzymeJAX Benchmarks
| Benchmark suite | Current: eede2db | Previous: 5fa1fb9 | Ratio |
|---|---|---|---|
| scatter_sum / JaX / cpu / Primal | 0.000004240672000742052 s | 0.000004337968700019701 s | 0.98 |
| scatter_sum / JaXPipe / cpu / Primal | 0.000004327424000075553 s | 0.000004346270799942431 s | 1.00 |
| scatter_sum / JaX / tpu / Primal | 0.0001453522068011 s | 0.0001424101054999 s | 1.02 |
| scatter_sum / JaXPipe / tpu / Primal | 0.0001543985108 s | 0.0001357899519 s | 1.14 |
This comment was automatically generated by workflow using github-action-benchmark.
Pull request overview
This PR improves HLO optimization by converting reduce(multiply(...)) patterns to more efficient dot_general operations. The key changes rename ReduceMulBroadcastToDotGeneral to ReduceMulToDotGeneral and generalize it to handle more cases beyond just broadcast patterns. Additionally, two new patterns (DotGeneralBroadcastInDim and DotGeneralBroadcastInDimSortDims) are introduced to simplify dot operations with broadcasted inputs by moving broadcasts after the dot_general when beneficial.
Key changes:
- Generalized `reduce(multiply)` to `dot_general` conversion to work on any multiply operation, not just broadcasts (see the MLIR sketch after this list)
- Added patterns to optimize `dot_general` operations with broadcasted operands
- Updated auto-batching infrastructure to support intermediate reshape operations
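A minimal sketch of the core rewrite, taken from the lit test discussed in the review thread above (the SSA value names here are illustrative):

```mlir
// Before: elementwise multiply followed by a reduce(add) over all dimensions.
%prod = stablehlo.multiply %arg1, %arg0 : tensor<32x32xf64>
%sum = stablehlo.reduce(%prod init: %cst) applies stablehlo.add across dimensions = [1, 0] : (tensor<32x32xf64>, tensor<f64>) -> tensor<f64>

// After: a single dot_general that contracts both dimensions of both operands.
%sum = stablehlo.dot_general %arg1, %arg0, contracting_dims = [1, 0] x [1, 0] : (tensor<32x32xf64>, tensor<32x32xf64>) -> tensor<f64>
```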
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| test/lit_tests/reducetranspose.mlir | Updated test expectations for the reduce-to-dot_general optimization; contains a bug in the return statement |
| test/lit_tests/reducereshape.mlir | Updated test expectations; contains an incorrect variable reference |
| test/lit_tests/reduce_mul_broadcast_to_dot_general.mlir | Expanded test cases for the generalized pattern with new examples |
| test/lit_tests/diffrules/stablehlo/while4.mlir | Updated variable numbering due to constant elimination |
| test/lit_tests/diagonal_dotgeneral.mlir | Updated expectations to use dot_general instead of multiply+reduce |
| test/lit_tests/concatreduce3.mlir | Added auto-batching pass and updated expectations for dot_general optimization |
| test/lit_tests/autobatching/reduce_loop.mlir | Updated expectations for batched dot_general with transpose |
| test/lit_tests/autobatching/concatreshapedotgeneral.mlir | Updated broadcast dimensions for optimized dot_general |
| src/enzyme_ad/jax/primitives.py | Renamed pattern from reduce_mul_broadcast_to_dot_general to reduce_mul_to_dot_general and added two new patterns |
| src/enzyme_ad/jax/TransformOps/TransformOps.td | Renamed pattern definition and added two new pattern definitions |
| src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | Generalized the ReduceMulToDotGeneral implementation and added the DotGeneralBroadcastInDim patterns; minor cleanup of commented-out code needed |
| src/enzyme_ad/jax/Passes/AutoBatching.h | Updated function signatures to support intermediate insertions |
| src/enzyme_ad/jax/Passes/AutoBatching.cpp | Enhanced auto-batching to handle intermediate reshapes and broadcasts with TypeSwitch pattern |
tests should now pass
Force-pushed from 16fff85 to eede2db