@@ -34,6 +34,7 @@ pub mod ffi {
34
34
SelectOp ,
35
35
ConcatenateOp ,
36
36
DotGeneralOp ,
37
+ ConvolutionOp ,
37
38
PadOp ,
38
39
SliceOp ,
39
40
TransposeOp ,
@@ -83,6 +84,12 @@ pub mod ffi {
83
84
pub vec : Vec < i64 > ,
84
85
}
85
86
87
+ // Similarly, we're creating a Matrix type for vecs of vecs (padding)
88
+ #[ derive( Debug ) ]
89
+ struct Matrix {
90
+ pub mat : Vec < Vector > ,
91
+ }
92
+
86
93
// take floats from c++ and wrap them into f32s below
87
94
extern "Rust" {
88
95
type Mdl ;
@@ -159,6 +166,29 @@ pub mod ffi {
159
166
dimension : i64 ,
160
167
output : Tensor ,
161
168
) -> Box < TensorInfo > ;
169
+ fn new_convolution_op (
170
+ self : & mut CppGraphConverter ,
171
+ lhs : & TensorInfo ,
172
+ rhs : & TensorInfo ,
173
+ windowStrides : Vec < i64 > ,
174
+ padding : Vec < Vector > ,
175
+ lhsDilation : Vec < i64 > ,
176
+ rhsDilation : Vec < i64 > ,
177
+ windowReversal : Vec < bool > ,
178
+ inputBatchDimension : i64 ,
179
+ inputFeatureDimension : i64 ,
180
+ inputSpatialDimension : Vec < i64 > ,
181
+ kernelInputFeatureDimension : i64 ,
182
+ kernelOutputFeatureDimension : i64 ,
183
+ kernelSpatialDimension : Vec < i64 > ,
184
+ outputBatchDimension : i64 ,
185
+ outputFeatureDimension : i64 ,
186
+ outputSpatialDimension : Vec < i64 > ,
187
+ featureGroupCount : i64 ,
188
+ batchGroupCount : i64 ,
189
+ precision_config : Vec < i64 > ,
190
+ output : Tensor ,
191
+ ) -> Box < TensorInfo > ;
162
192
fn new_dot_general_op (
163
193
self : & mut CppGraphConverter ,
164
194
lhs : & TensorInfo ,
@@ -274,7 +304,7 @@ pub mod ffi {
274
304
fn new_blackbox_op (
275
305
self : & mut CppGraphConverter ,
276
306
inpts : & [ * mut TensorInfo ] ,
277
- captured : & [ * mut TensorInfo ] , // values that appear in a block that was declared outside
307
+ captured : & [ * mut TensorInfo ] , // values that appear in a block that was declared outside
278
308
cpp_num : i64 ,
279
309
outputs : & Vec < Tensor > ,
280
310
) -> Box < TensorInfo > ;
@@ -293,6 +323,7 @@ pub mod ffi {
293
323
operands : Vec < Tensor > ,
294
324
other_vector_args : Vec < Vector > ,
295
325
int_args : Vec < i64 > ,
326
+ matrix_args : Vec < Matrix > ,
296
327
) -> u64 ;
297
328
}
298
329
@@ -304,6 +335,7 @@ pub mod ffi {
304
335
operands : Vec < Tensor > ,
305
336
other_vector_args : Vec < Vector > ,
306
337
int_args : Vec < i64 > ,
338
+ matrix_args : Vec < Matrix > ,
307
339
) -> Vec < Tensor > ;
308
340
}
309
341
}
@@ -356,6 +388,7 @@ impl ffi::Ops {
356
388
Mdl :: PadOp ( _) => Ops :: PadOp ,
357
389
Mdl :: SliceOp ( _) => Ops :: SliceOp ,
358
390
Mdl :: TransposeOp ( _) => Ops :: TransposeOp ,
391
+ Mdl :: ConvolutionOp ( _) => Ops :: ConvolutionOp ,
359
392
Mdl :: MulOp ( _) => Ops :: MulOp ,
360
393
Mdl :: AddOp ( _) => Ops :: AddOp ,
361
394
Mdl :: DivOp ( _) => Ops :: DivOp ,
@@ -601,10 +634,7 @@ impl CppGraphConverter {
601
634
Box :: new ( res)
602
635
}
603
636
604
- fn new_tensorinfo_vec (
605
- & mut self ,
606
- inputs : & [ * mut TensorInfo ]
607
- ) -> Id {
637
+ fn new_tensorinfo_vec ( & mut self , inputs : & [ * mut TensorInfo ] ) -> Id {
608
638
let tensor_infos: Vec < & TensorInfo > = inputs. iter ( ) . map ( |& ptr| unsafe { & * ptr } ) . collect ( ) ;
609
639
let inputs_node = Mdl :: Vec ( tensor_infos. iter ( ) . map ( |i| i. id ) . collect ( ) ) ;
610
640
self . rec_expr . add ( inputs_node)
@@ -630,6 +660,79 @@ impl CppGraphConverter {
630
660
Box :: new ( res)
631
661
}
632
662
663
+ pub fn new_convolution_op (
664
+ & mut self ,
665
+ lhs : & TensorInfo ,
666
+ rhs : & TensorInfo ,
667
+ window_strides : Vec < i64 > ,
668
+ padding : Vec < ffi:: Vector > ,
669
+ lhs_dilation : Vec < i64 > ,
670
+ rhs_dilation : Vec < i64 > ,
671
+ window_reversal : Vec < bool > ,
672
+ input_batch_dimension : i64 ,
673
+ input_feature_dimension : i64 ,
674
+ input_spatial_dimensions : Vec < i64 > ,
675
+ kernel_input_feature_dimension : i64 ,
676
+ kernel_output_feature_dimension : i64 ,
677
+ kernel_spatial_dimensions : Vec < i64 > ,
678
+ output_batch_dimension : i64 ,
679
+ output_feature_dimension : i64 ,
680
+ output_spatial_dimensions : Vec < i64 > ,
681
+ feature_group_count : i64 ,
682
+ batch_group_count : i64 ,
683
+ precision_config : Vec < i64 > ,
684
+ output : ffi:: Tensor ,
685
+ ) -> Box < TensorInfo > {
686
+ let window_strides_node_id = self . vec_node ( window_strides) ;
687
+ let lhs_dilation_node_id = self . vec_node ( lhs_dilation) ;
688
+ let rhs_dilation_node_id = self . vec_node ( rhs_dilation) ;
689
+
690
+ // We could add a bool element type vec?
691
+ let window_reversal_node_id =
692
+ self . vec_node ( window_reversal. iter ( ) . map ( |x| * x as i64 ) . collect ( ) ) ;
693
+ let input_spatial_dimensions_node_id = self . vec_node ( input_spatial_dimensions) ;
694
+ let kernel_spatial_dimensions_node_id = self . vec_node ( kernel_spatial_dimensions) ;
695
+ let output_spatial_dimensions_node_id = self . vec_node ( output_spatial_dimensions) ;
696
+ let precision_config_node_id = self . vec_node ( precision_config) ;
697
+
698
+ let padding_node_ids: Vec < Id > = padding
699
+ . into_iter ( )
700
+ . map ( |pad| self . vec_node ( pad. vec ) )
701
+ . collect :: < Vec < Id > > ( ) ;
702
+ let padding_node_id = self . rec_expr . add ( Mdl :: Vec ( padding_node_ids) ) ;
703
+
704
+ let new_node = Mdl :: ConvolutionOp ( [
705
+ lhs. id ,
706
+ rhs. id ,
707
+ window_strides_node_id,
708
+ padding_node_id,
709
+ lhs_dilation_node_id,
710
+ rhs_dilation_node_id,
711
+ window_reversal_node_id,
712
+ self . add_or_get_val ( input_batch_dimension) ,
713
+ self . add_or_get_val ( input_feature_dimension) ,
714
+ input_spatial_dimensions_node_id,
715
+ self . add_or_get_val ( kernel_input_feature_dimension) ,
716
+ self . add_or_get_val ( kernel_output_feature_dimension) ,
717
+ kernel_spatial_dimensions_node_id,
718
+ self . add_or_get_val ( output_batch_dimension) ,
719
+ self . add_or_get_val ( output_feature_dimension) ,
720
+ output_spatial_dimensions_node_id,
721
+ self . add_or_get_val ( feature_group_count) ,
722
+ self . add_or_get_val ( batch_group_count) ,
723
+ precision_config_node_id,
724
+ ] ) ;
725
+
726
+ let res = TensorInfo {
727
+ id : self . rec_expr . add ( new_node) ,
728
+ tensor_data : TensorData {
729
+ tensors : vec ! [ output] ,
730
+ name : None ,
731
+ } ,
732
+ } ;
733
+ Box :: new ( res)
734
+ }
735
+
633
736
pub fn new_dot_general_op (
634
737
& mut self ,
635
738
lhs : & TensorInfo ,
@@ -1043,6 +1146,7 @@ impl CppGraphConverter {
1043
1146
Mdl :: DotGeneralOp ( ops) => new_node ( ops) ,
1044
1147
Mdl :: SliceOp ( ops) => new_node ( ops) ,
1045
1148
Mdl :: TransposeOp ( ops) => new_node ( ops) ,
1149
+ Mdl :: ConvolutionOp ( ops) => new_node ( ops) ,
1046
1150
Mdl :: MulOp ( ops) => new_node ( ops) ,
1047
1151
Mdl :: AddOp ( ops) => new_node ( ops) ,
1048
1152
Mdl :: DivOp ( ops) => new_node ( ops) ,
@@ -1059,7 +1163,7 @@ impl CppGraphConverter {
1059
1163
Mdl :: SSplit0 ( ops) => new_node ( ops) ,
1060
1164
Mdl :: SSplit1 ( ops) => new_node ( ops) ,
1061
1165
Mdl :: MatchRank ( ops) => new_node ( ops) ,
1062
- _ => unimplemented ! ( )
1166
+ _ => unimplemented ! ( ) ,
1063
1167
} ;
1064
1168
1065
1169
res. push ( node) ;
@@ -1088,7 +1192,8 @@ impl CppGraphConverter {
1088
1192
read_to_string ( rule_file) . expect ( "Something went wrong reading the rule file" ) ;
1089
1193
let time_limit_sec = Duration :: new ( n_sec, 0 ) ;
1090
1194
let pre_defined_rules = PRE_DEFINED_RULES . iter ( ) . map ( |& x| x) ;
1091
- let split_rules: Vec < & str > = learned_rules. split ( "\n " )
1195
+ let split_rules: Vec < & str > = learned_rules
1196
+ . split ( "\n " )
1092
1197
. filter ( |x| !x. is_empty ( ) )
1093
1198
. chain ( pre_defined_rules)
1094
1199
. collect ( ) ;
@@ -1234,7 +1339,10 @@ fn extract_by_ilp(
1234
1339
let class_constraint = true ;
1235
1340
let no_order = true ;
1236
1341
let initialise_with_greedy = false ;
1237
- let fusion_costs: bool = std:: env:: var ( "FUSION_COSTS" ) . unwrap_or ( String :: from ( "true" ) ) . parse ( ) . unwrap ( ) ;
1342
+ let fusion_costs: bool = std:: env:: var ( "FUSION_COSTS" )
1343
+ . unwrap_or ( String :: from ( "true" ) )
1344
+ . parse ( )
1345
+ . unwrap ( ) ;
1238
1346
let mut arg_vec = vec ! [ "src/enzyme_ad/jax/deps/tensat/extractor/extract.py" ] ;
1239
1347
if order_var_int {
1240
1348
arg_vec. push ( "--order_var_int" ) ;
0 commit comments