Skip to content

Commit 2297685

Browse files
committed
feat: try collapsing the ops again
1 parent 9c9b232 commit 2297685

File tree

1 file changed

+37
-5
lines changed

1 file changed

+37
-5
lines changed

src/Compiler.jl

+37-5
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,11 @@ end
388388

389389
# Optimization passes via transform dialect
390390
function optimization_passes(;
391-
no_nan::Bool=false, sroa::Bool=false, inline::Bool=true, transpose_propagate::Symbol=:up
391+
no_nan::Bool=false,
392+
sroa::Bool=false,
393+
inline::Bool=true,
394+
transpose_propagate::Symbol=:up,
395+
reshape_propagate::Symbol=:up,
392396
)
393397
transform_passes_list = [
394398
"patterns=compare_op_canon<16>",
@@ -485,7 +489,6 @@ function optimization_passes(;
485489
"bin_broadcast_splat_subtract<1>",
486490
"bin_broadcast_splat_div<1>",
487491
"bin_broadcast_splat_mul<1>",
488-
"reshape_iota<16>",
489492
"slice_reshape_slice<1>",
490493
"dot_general_simplify<16>",
491494
"transpose_simplify<16>",
@@ -528,7 +531,6 @@ function optimization_passes(;
528531
"convolution_transpose<1>",
529532
"convert_convert_float<1>",
530533
"concat_to_pad<1>",
531-
"concat_appending_reshape<1>",
532534
"reshape_iota<1>",
533535
"broadcast_reduce<1>",
534536
"slice_dot_general<1>",
@@ -578,6 +580,14 @@ function optimization_passes(;
578580
"abs_positive_simplify",
579581
]
580582

583+
if reshape_propagate === :up
584+
append!(transform_passes_list, ["concat_appending_reshape"])
585+
elseif reshape_propagate === :down
586+
append!(transform_passes_list, ["slice_reshape"])
587+
else
588+
error("Invalid value for reshape_propagate. Must be :up or :down.")
589+
end
590+
581591
if transpose_propagate === :up
582592
append!(
583593
transform_passes_list,
@@ -816,6 +826,7 @@ function compile_mlir!(
816826
shardy_passes::Symbol=:to_mhlo_shardings, # :none | :to_mhlo_shardings
817827
no_nan::Bool=false,
818828
transpose_propagate::Symbol=:up,
829+
reshape_propagate::Symbol=:up,
819830
assert_nonallocating::Bool=false,
820831
backend="gpu",
821832
fn_kwargs=(),
@@ -925,8 +936,12 @@ function compile_mlir!(
925936
jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
926937
end
927938

928-
opt_passes = optimization_passes(; no_nan, sroa=true, transpose_propagate)
929-
opt_passes2 = optimization_passes(; no_nan, sroa=false, transpose_propagate)
939+
opt_passes = optimization_passes(;
940+
no_nan, sroa=true, transpose_propagate, reshape_propagate
941+
)
942+
opt_passes2 = optimization_passes(;
943+
no_nan, sroa=false, transpose_propagate, reshape_propagate
944+
)
930945

931946
raise_passes = if raise isa String
932947
# Raising passes were specified
@@ -1088,6 +1103,18 @@ function compile_mlir!(
10881103
error("Invalid optimize option: $(Meta.quot(optimize))")
10891104
end
10901105

1106+
if optimize (:none, :just_batch, :canonicalize) &&
1107+
(transpose_propagate == :up || reshape_propagate == :up)
1108+
# We tried propagating reshapes and transposes up. If at this point we are left with
1109+
# them, we propagate them down to minimize the number of Ops in the IR.
1110+
run_pass_pipeline!(
1111+
mod,
1112+
optimization_passes(;
1113+
transpose_propagate::Symbol=:down, reshape_propagate::Symbol=:down
1114+
),
1115+
)
1116+
end
1117+
10911118
# shardy passes
10921119
use_shardy_partitioner = false
10931120
result_shardings = missing
@@ -1260,6 +1287,7 @@ macro code_hlo(args...)
12601287
:assert_nonallocating => false,
12611288
:donated_args => :(:auto),
12621289
:transpose_propagate => :(:up),
1290+
:reshape_propagate => :(:up),
12631291
)
12641292
compile_expr, (; compiled) = compile_call_expr(
12651293
__module__, compile_mlir, default_options, args...
@@ -1291,6 +1319,7 @@ macro code_mhlo(args...)
12911319
:assert_nonallocating => false,
12921320
:donated_args => :(:auto),
12931321
:transpose_propagate => :(:up),
1322+
:reshape_propagate => :(:up),
12941323
)
12951324
compile_expr, (; compiled) = compile_call_expr(
12961325
__module__, compile_xla, default_options, args...
@@ -1322,6 +1351,7 @@ macro code_xla(args...)
13221351
:assert_nonallocating => false,
13231352
:donated_args => :(:auto),
13241353
:transpose_propagate => :(:up),
1354+
:reshape_propagate => :(:up),
13251355
)
13261356
compile_expr, (; compiled) = compile_call_expr(
13271357
__module__, compile_xla, default_options, args...
@@ -1353,6 +1383,7 @@ macro compile(args...)
13531383
:serializable => false,
13541384
:donated_args => :(:auto),
13551385
:transpose_propagate => :(:up),
1386+
:reshape_propagate => :(:up),
13561387
)
13571388
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
13581389
end
@@ -1373,6 +1404,7 @@ macro jit(args...)
13731404
:assert_nonallocating => false,
13741405
:donated_args => :(:auto),
13751406
:transpose_propagate => :(:up),
1407+
:reshape_propagate => :(:up),
13761408
)
13771409
compile_expr, (; compiled, args) = compile_call_expr(
13781410
__module__, compile, default_options, args...

0 commit comments

Comments
 (0)