@@ -1036,21 +1036,25 @@ function raising!(f, is_raising::Bool)
1036
1036
end
1037
1037
end
1038
1038
1039
- # TODO investigate which options need enable/disable
1040
- # const comm_pass = "optimize-communication{periodic_concat=1 rotate_comm=1 wrap_comm=1 dus_to_pad_manual_comp_comm=1 dus_to_pad_comm=0 concat_two_operands_comm=0 concat_to_pad_comm=0 extend_to_pad_comm=0 wrap_to_pad_comm=0 concat_two_dus_like=1 extend_dus_like=1}"
1041
- const comm_pass = " optimize-communication"
1042
-
1043
- const optimize_comms_passes = (
1044
- # rotate handler presently broken (and handled okay presently), disabling for now
1045
- " enzyme-hlo-generate-td{patterns=lower_rotate;concat_to_onedim_dus;concat_to_onedim_dusslice}" ,
1046
- " transform-interpreter" ,
1047
- " enzyme-hlo-remove-transform" ,
1048
- comm_pass,
1049
- " enzyme-hlo-generate-td{patterns=lower_rotate;lower_wrap;lower_extend}" ,
1050
- " transform-interpreter" ,
1051
- " enzyme-hlo-remove-transform" ,
1052
- comm_pass,
1053
- )
1039
+ function get_optimize_comms_passes (options:: Bool )
1040
+ options || return String[]
1041
+ return get_optimize_comms_passes (Reactant. OptimizeCommunicationOptions ())
1042
+ end
1043
+
1044
+ function get_optimize_comms_passes (options:: Reactant.OptimizeCommunicationOptions )
1045
+ options_str = String (options)
1046
+ res = [
1047
+ " enzyme-hlo-generate-td{patterns=lower_rotate;concat_to_onedim_dus;concat_to_onedim_dusslice;concatreshape_to_onedim_dus}" ,
1048
+ " transform-interpreter" ,
1049
+ " enzyme-hlo-remove-transform" ,
1050
+ options_str,
1051
+ " enzyme-hlo-generate-td{patterns=lower_rotate;lower_wrap;lower_extend}" ,
1052
+ " transform-interpreter" ,
1053
+ " enzyme-hlo-remove-transform" ,
1054
+ options_str,
1055
+ ]
1056
+ return res
1057
+ end
1054
1058
1055
1059
function compile_mlir! (
1056
1060
mod,
@@ -1065,6 +1069,7 @@ function compile_mlir!(
1065
1069
no_nan:: Bool = false ,
1066
1070
transpose_propagate:: Symbol = :up ,
1067
1071
reshape_propagate:: Symbol = :up ,
1072
+ optimize_communications:: Bool = true ,
1068
1073
assert_nonallocating:: Bool = false ,
1069
1074
backend= " gpu" ,
1070
1075
raise:: Union{Bool,String} = false ,
@@ -1459,7 +1464,7 @@ function compile_mlir!(
1459
1464
sym_visibility= MLIR. IR. attr (compiled_f, " private" ),
1460
1465
)
1461
1466
fnbody = MLIR. IR. Block (
1462
- in_tys_padded, [MLIR. IR. Location () for _ in in_tys_padded ]
1467
+ in_tys_padded, [MLIR. IR. Location (MLIR . API . mlirValueGetLocation (MLIR . IR . argument (MLIR . IR . first_block (MLIR . IR . region (compiled_f, 1 )), i))) for i in 1 : length (linear_args) ]
1463
1468
)
1464
1469
push! (MLIR. IR. region (func_with_padding, 1 ), fnbody)
1465
1470
MLIR. IR. activate! (fnbody)
@@ -1592,7 +1597,7 @@ function compile_mlir!(
1592
1597
join (
1593
1598
[
1594
1599
" sdy-close-shardings" ,
1595
- optimize_comms_passes ... ,
1600
+ get_optimize_comms_passes (optimize_communications) ... ,
1596
1601
" xla-sdy-stablehlo-export-pipeline" ,
1597
1602
],
1598
1603
" ," ,
@@ -1606,7 +1611,7 @@ function compile_mlir!(
1606
1611
[
1607
1612
" sdy-propagation-pipeline" ,
1608
1613
" sdy-close-shardings" ,
1609
- optimize_comms_passes ... ,
1614
+ get_optimize_comms_passes (optimize_communications) ... ,
1610
1615
" xla-sdy-stablehlo-export-pipeline" ,
1611
1616
],
1612
1617
" ," ,
@@ -1764,6 +1769,7 @@ macro code_hlo(args...)
1764
1769
:transpose_propagate => :(:up ),
1765
1770
:reshape_propagate => :(:up ),
1766
1771
:optimize_then_pad => true ,
1772
+ :optimize_communications => true ,
1767
1773
)
1768
1774
compile_expr, (; compiled) = compile_call_expr (
1769
1775
__module__, compile_mlir, default_options, args...
@@ -1797,6 +1803,7 @@ macro code_mhlo(args...)
1797
1803
:transpose_propagate => :(:up ),
1798
1804
:reshape_propagate => :(:up ),
1799
1805
:optimize_then_pad => true ,
1806
+ :optimize_communications => true ,
1800
1807
)
1801
1808
compile_expr, (; compiled) = compile_call_expr (
1802
1809
__module__, compile_xla, default_options, args...
@@ -1830,6 +1837,7 @@ macro code_xla(args...)
1830
1837
:transpose_propagate => :(:up ),
1831
1838
:reshape_propagate => :(:up ),
1832
1839
:optimize_then_pad => true ,
1840
+ :optimize_communications => true ,
1833
1841
)
1834
1842
compile_expr, (; compiled) = compile_call_expr (
1835
1843
__module__, compile_xla, default_options, args...
@@ -1863,6 +1871,7 @@ macro compile(args...)
1863
1871
:transpose_propagate => :(:up ),
1864
1872
:reshape_propagate => :(:up ),
1865
1873
:optimize_then_pad => true ,
1874
+ :optimize_communications => true ,
1866
1875
)
1867
1876
return esc (first (compile_call_expr (__module__, compile, default_options, args... )))
1868
1877
end
@@ -1885,6 +1894,7 @@ macro jit(args...)
1885
1894
:transpose_propagate => :(:up ),
1886
1895
:reshape_propagate => :(:up ),
1887
1896
:optimize_then_pad => true ,
1897
+ :optimize_communications => true ,
1888
1898
)
1889
1899
compile_expr, (; compiled, args) = compile_call_expr (
1890
1900
__module__, compile, default_options, args...
0 commit comments