Skip to content

Commit a1d868b

Browse files
avik-palwsmoses
andauthoredApr 13, 2025
feat: pass in new options to optimize-comm (#1172)
* feat: pass in new options to optimize-comm [skip ci] * feat: more configurable options * fix: ambiguity * wip * rm print --------- Co-authored-by: William Moses <[email protected]> Co-authored-by: William S. Moses <[email protected]>
1 parent d6d5546 commit a1d868b

File tree

4 files changed

+57
-19
lines changed

4 files changed

+57
-19
lines changed
 

‎src/CompileOptions.jl

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# TODO: make the other optimize options into a struct as well
2+
@kwdef struct OptimizeCommunicationOptions
3+
periodic_concat::Int = 0
4+
rotate_comm::Int = 0
5+
rotate_to_pad_comm::Int = 1
6+
wrap_comm::Int = 0
7+
extend_comm::Int = 0
8+
dus_to_pad_manual_comp_comm::Int = 0
9+
dus_to_pad_comm::Int = 1
10+
concat_two_operands_comm::Int = 0
11+
concat_to_pad_comm::Int = 1
12+
extend_to_pad_comm::Int = 1
13+
wrap_to_pad_comm::Int = 1
14+
end
15+
16+
function Base.String(options::OptimizeCommunicationOptions)
17+
return (
18+
"optimize-communication{" *
19+
join(["$(f)=$(getfield(options, f))" for f in fieldnames(typeof(options))], " ") *
20+
"}"
21+
)
22+
end

‎src/Compiler.jl

+28-18
Original file line numberDiff line numberDiff line change
@@ -1036,21 +1036,25 @@ function raising!(f, is_raising::Bool)
10361036
end
10371037
end
10381038

1039-
# TODO investigate which options need enable/disable
1040-
# const comm_pass = "optimize-communication{periodic_concat=1 rotate_comm=1 wrap_comm=1 dus_to_pad_manual_comp_comm=1 dus_to_pad_comm=0 concat_two_operands_comm=0 concat_to_pad_comm=0 extend_to_pad_comm=0 wrap_to_pad_comm=0 concat_two_dus_like=1 extend_dus_like=1}"
1041-
const comm_pass = "optimize-communication"
1042-
1043-
const optimize_comms_passes = (
1044-
# rotate handler presently broken (and handled okay presently), disabling for now
1045-
"enzyme-hlo-generate-td{patterns=lower_rotate;concat_to_onedim_dus;concat_to_onedim_dusslice}",
1046-
"transform-interpreter",
1047-
"enzyme-hlo-remove-transform",
1048-
comm_pass,
1049-
"enzyme-hlo-generate-td{patterns=lower_rotate;lower_wrap;lower_extend}",
1050-
"transform-interpreter",
1051-
"enzyme-hlo-remove-transform",
1052-
comm_pass,
1053-
)
1039+
function get_optimize_comms_passes(options::Bool)
1040+
options || return String[]
1041+
return get_optimize_comms_passes(Reactant.OptimizeCommunicationOptions())
1042+
end
1043+
1044+
function get_optimize_comms_passes(options::Reactant.OptimizeCommunicationOptions)
1045+
options_str = String(options)
1046+
res = [
1047+
"enzyme-hlo-generate-td{patterns=lower_rotate;concat_to_onedim_dus;concat_to_onedim_dusslice;concatreshape_to_onedim_dus}",
1048+
"transform-interpreter",
1049+
"enzyme-hlo-remove-transform",
1050+
options_str,
1051+
"enzyme-hlo-generate-td{patterns=lower_rotate;lower_wrap;lower_extend}",
1052+
"transform-interpreter",
1053+
"enzyme-hlo-remove-transform",
1054+
options_str,
1055+
]
1056+
return res
1057+
end
10541058

10551059
function compile_mlir!(
10561060
mod,
@@ -1065,6 +1069,7 @@ function compile_mlir!(
10651069
no_nan::Bool=false,
10661070
transpose_propagate::Symbol=:up,
10671071
reshape_propagate::Symbol=:up,
1072+
optimize_communications::Bool=true,
10681073
assert_nonallocating::Bool=false,
10691074
backend="gpu",
10701075
raise::Union{Bool,String}=false,
@@ -1459,7 +1464,7 @@ function compile_mlir!(
14591464
sym_visibility=MLIR.IR.attr(compiled_f, "private"),
14601465
)
14611466
fnbody = MLIR.IR.Block(
1462-
in_tys_padded, [MLIR.IR.Location() for _ in in_tys_padded]
1467+
in_tys_padded, [MLIR.IR.Location(MLIR.API.mlirValueGetLocation(MLIR.IR.argument(MLIR.IR.first_block(MLIR.IR.region(compiled_f, 1)), i))) for i in 1:length(linear_args)]
14631468
)
14641469
push!(MLIR.IR.region(func_with_padding, 1), fnbody)
14651470
MLIR.IR.activate!(fnbody)
@@ -1592,7 +1597,7 @@ function compile_mlir!(
15921597
join(
15931598
[
15941599
"sdy-close-shardings",
1595-
optimize_comms_passes...,
1600+
get_optimize_comms_passes(optimize_communications)...,
15961601
"xla-sdy-stablehlo-export-pipeline",
15971602
],
15981603
",",
@@ -1606,7 +1611,7 @@ function compile_mlir!(
16061611
[
16071612
"sdy-propagation-pipeline",
16081613
"sdy-close-shardings",
1609-
optimize_comms_passes...,
1614+
get_optimize_comms_passes(optimize_communications)...,
16101615
"xla-sdy-stablehlo-export-pipeline",
16111616
],
16121617
",",
@@ -1764,6 +1769,7 @@ macro code_hlo(args...)
17641769
:transpose_propagate => :(:up),
17651770
:reshape_propagate => :(:up),
17661771
:optimize_then_pad => true,
1772+
:optimize_communications => true,
17671773
)
17681774
compile_expr, (; compiled) = compile_call_expr(
17691775
__module__, compile_mlir, default_options, args...
@@ -1797,6 +1803,7 @@ macro code_mhlo(args...)
17971803
:transpose_propagate => :(:up),
17981804
:reshape_propagate => :(:up),
17991805
:optimize_then_pad => true,
1806+
:optimize_communications => true,
18001807
)
18011808
compile_expr, (; compiled) = compile_call_expr(
18021809
__module__, compile_xla, default_options, args...
@@ -1830,6 +1837,7 @@ macro code_xla(args...)
18301837
:transpose_propagate => :(:up),
18311838
:reshape_propagate => :(:up),
18321839
:optimize_then_pad => true,
1840+
:optimize_communications => true,
18331841
)
18341842
compile_expr, (; compiled) = compile_call_expr(
18351843
__module__, compile_xla, default_options, args...
@@ -1863,6 +1871,7 @@ macro compile(args...)
18631871
:transpose_propagate => :(:up),
18641872
:reshape_propagate => :(:up),
18651873
:optimize_then_pad => true,
1874+
:optimize_communications => true,
18661875
)
18671876
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
18681877
end
@@ -1885,6 +1894,7 @@ macro jit(args...)
18851894
:transpose_propagate => :(:up),
18861895
:reshape_propagate => :(:up),
18871896
:optimize_then_pad => true,
1897+
:optimize_communications => true,
18881898
)
18891899
compile_expr, (; compiled, args) = compile_call_expr(
18901900
__module__, compile, default_options, args...

‎src/Reactant.jl

+4
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
179179

180180
include("ControlFlow.jl")
181181
include("Tracing.jl")
182+
183+
include("CompileOptions.jl")
184+
export OptimizeCommunicationOptions
185+
182186
include("Compiler.jl")
183187

184188
include("Overlay.jl")

‎src/Sharding.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,9 @@ function NamedSharding(sharding::Replicated, ndims::Int)
732732
return NamedSharding(sharding.mesh, ntuple(Returns(nothing), ndims))
733733
end
734734

735-
function (sharding::Replicated)(client::XLA.AbstractClient, device, x)
735+
function (sharding::Replicated)(
736+
client::XLA.AbstractClient, device, x::Union{AbstractArray,Number}
737+
)
736738
return (NamedSharding(sharding, ndims(x)))(client, device, x)
737739
end
738740

0 commit comments

Comments
 (0)