388
388
389
389
# Optimization passes via transform dialect
390
390
function optimization_passes (;
391
- no_nan:: Bool = false , sroa:: Bool = false , inline:: Bool = true , transpose_propagate:: Symbol = :up
391
+ no_nan:: Bool = false ,
392
+ sroa:: Bool = false ,
393
+ inline:: Bool = true ,
394
+ transpose_propagate:: Symbol = :up ,
395
+ reshape_propagate:: Symbol = :up ,
392
396
)
393
397
transform_passes_list = [
394
398
" patterns=compare_op_canon<16>" ,
@@ -485,7 +489,6 @@ function optimization_passes(;
485
489
" bin_broadcast_splat_subtract<1>" ,
486
490
" bin_broadcast_splat_div<1>" ,
487
491
" bin_broadcast_splat_mul<1>" ,
488
- " reshape_iota<16>" ,
489
492
" slice_reshape_slice<1>" ,
490
493
" dot_general_simplify<16>" ,
491
494
" transpose_simplify<16>" ,
@@ -528,7 +531,6 @@ function optimization_passes(;
528
531
" convolution_transpose<1>" ,
529
532
" convert_convert_float<1>" ,
530
533
" concat_to_pad<1>" ,
531
- " concat_appending_reshape<1>" ,
532
534
" reshape_iota<1>" ,
533
535
" broadcast_reduce<1>" ,
534
536
" slice_dot_general<1>" ,
@@ -578,6 +580,14 @@ function optimization_passes(;
578
580
" abs_positive_simplify" ,
579
581
]
580
582
583
+ if reshape_propagate === :up
584
+ append! (transform_passes_list, [" concat_appending_reshape" ])
585
+ elseif reshape_propagate === :down
586
+ append! (transform_passes_list, [" slice_reshape" ])
587
+ else
588
+ error (" Invalid value for reshape_propagate. Must be :up or :down." )
589
+ end
590
+
581
591
if transpose_propagate === :up
582
592
append! (
583
593
transform_passes_list,
@@ -816,6 +826,7 @@ function compile_mlir!(
816
826
shardy_passes:: Symbol = :to_mhlo_shardings , # :none | :to_mhlo_shardings
817
827
no_nan:: Bool = false ,
818
828
transpose_propagate:: Symbol = :up ,
829
+ reshape_propagate:: Symbol = :up ,
819
830
assert_nonallocating:: Bool = false ,
820
831
backend= " gpu" ,
821
832
fn_kwargs= (),
@@ -925,8 +936,12 @@ function compile_mlir!(
925
936
jit = " lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures ()) run_init=true toolkitPath=$toolkit },symbol-dce"
926
937
end
927
938
928
- opt_passes = optimization_passes (; no_nan, sroa= true , transpose_propagate)
929
- opt_passes2 = optimization_passes (; no_nan, sroa= false , transpose_propagate)
939
+ opt_passes = optimization_passes (;
940
+ no_nan, sroa= true , transpose_propagate, reshape_propagate
941
+ )
942
+ opt_passes2 = optimization_passes (;
943
+ no_nan, sroa= false , transpose_propagate, reshape_propagate
944
+ )
930
945
931
946
raise_passes = if raise isa String
932
947
# Raising passes were specified
@@ -1088,6 +1103,18 @@ function compile_mlir!(
1088
1103
error (" Invalid optimize option: $(Meta. quot (optimize)) " )
1089
1104
end
1090
1105
1106
+ if optimize ∉ (:none , :just_batch , :canonicalize ) &&
1107
+ (transpose_propagate == :up || reshape_propagate == :up )
1108
+ # We tried propagating reshapes and transposes up. If at this point we are left with
1109
+ # them, we propagate them down to minimize the number of Ops in the IR.
1110
+ run_pass_pipeline! (
1111
+ mod,
1112
+ optimization_passes (;
1113
+ transpose_propagate:: Symbol = :down , reshape_propagate:: Symbol = :down
1114
+ ),
1115
+ )
1116
+ end
1117
+
1091
1118
# shardy passes
1092
1119
use_shardy_partitioner = false
1093
1120
result_shardings = missing
@@ -1260,6 +1287,7 @@ macro code_hlo(args...)
1260
1287
:assert_nonallocating => false ,
1261
1288
:donated_args => :(:auto ),
1262
1289
:transpose_propagate => :(:up ),
1290
+ :reshape_propagate => :(:up ),
1263
1291
)
1264
1292
compile_expr, (; compiled) = compile_call_expr (
1265
1293
__module__, compile_mlir, default_options, args...
@@ -1291,6 +1319,7 @@ macro code_mhlo(args...)
1291
1319
:assert_nonallocating => false ,
1292
1320
:donated_args => :(:auto ),
1293
1321
:transpose_propagate => :(:up ),
1322
+ :reshape_propagate => :(:up ),
1294
1323
)
1295
1324
compile_expr, (; compiled) = compile_call_expr (
1296
1325
__module__, compile_xla, default_options, args...
@@ -1322,6 +1351,7 @@ macro code_xla(args...)
1322
1351
:assert_nonallocating => false ,
1323
1352
:donated_args => :(:auto ),
1324
1353
:transpose_propagate => :(:up ),
1354
+ :reshape_propagate => :(:up ),
1325
1355
)
1326
1356
compile_expr, (; compiled) = compile_call_expr (
1327
1357
__module__, compile_xla, default_options, args...
@@ -1353,6 +1383,7 @@ macro compile(args...)
1353
1383
:serializable => false ,
1354
1384
:donated_args => :(:auto ),
1355
1385
:transpose_propagate => :(:up ),
1386
+ :reshape_propagate => :(:up ),
1356
1387
)
1357
1388
return esc (first (compile_call_expr (__module__, compile, default_options, args... )))
1358
1389
end
@@ -1373,6 +1404,7 @@ macro jit(args...)
1373
1404
:assert_nonallocating => false ,
1374
1405
:donated_args => :(:auto ),
1375
1406
:transpose_propagate => :(:up ),
1407
+ :reshape_propagate => :(:up ),
1376
1408
)
1377
1409
compile_expr, (; compiled, args) = compile_call_expr (
1378
1410
__module__, compile, default_options, args...
0 commit comments