Skip to content

Commit a3646e0

Browse files
authored
Merge pull request #9 from aryavohra/better-benchmark
Convolution support
2 parents 34d7d92 + 9c8ecbf commit a3646e0

10 files changed

+686
-62
lines changed

src/enzyme_ad/jax/Passes/EqualitySaturation.cpp

+489-37
Large diffs are not rendered by default.

src/enzyme_ad/jax/Passes/EqualitySaturation.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ namespace tensat {
99
enum class Type : uint8_t;
1010
enum class Ops : uint8_t;
1111
struct Vector;
12+
struct Matrix;
1213
struct Tensor;
1314

1415
/**
@@ -17,12 +18,14 @@ struct Tensor;
1718

1819
uint64_t get_cost(Ops op, rust::Vec<tensat::Tensor> operands,
1920
rust::Vec<tensat::Vector> other_vector_args,
20-
rust::Vec<int64_t> int_args);
21+
rust::Vec<int64_t> int_args,
22+
rust::Vec<tensat::Matrix> matrix_args);
2123

2224
mlir::Type newTensorType(mlir::OpBuilder &builder, Tensor tensor);
2325
mlir::Type tensatTypeToMlirType(mlir::OpBuilder &builder, Type type);
2426

2527
rust::Vec<Tensor> get_shape(Ops op, rust::Vec<tensat::Tensor> operands,
2628
rust::Vec<tensat::Vector> other_vector_args,
27-
rust::Vec<int64_t> int_args);
29+
rust::Vec<int64_t> int_args,
30+
rust::Vec<tensat::Matrix> matrix_args);
2831
} // namespace tensat

src/enzyme_ad/jax/deps/tensat/Cargo.Bazel.lock

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ dependencies = [
207207
[[package]]
208208
name = "egg"
209209
version = "0.6.1-dev"
210-
source = "git+https://github.com/yycdavid/egg?rev=12cc1ee7731d37fe91901c81f59678fa1d08a2bb#12cc1ee7731d37fe91901c81f59678fa1d08a2bb"
210+
source = "git+https://github.com/aryavohra/egg?rev=b30d14cff61bff97336323f6eb0978cc7769140d#b30d14cff61bff97336323f6eb0978cc7769140d"
211211
dependencies = [
212212
"indexmap",
213213
"instant",

src/enzyme_ad/jax/deps/tensat/Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/enzyme_ad/jax/deps/tensat/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ serde_json = "1.0"
2525
serde = { version = "1.0", features = ["derive"] }
2626

2727
[dependencies.egg]
28-
git = "https://github.com/yycdavid/egg"
29-
rev = "12cc1ee7731d37fe91901c81f59678fa1d08a2bb"
28+
git = "https://github.com/aryavohra/egg"
29+
rev = "b30d14cff61bff97336323f6eb0978cc7769140d"
3030

3131
[package.metadata.cxx]
3232
library = "c++"

src/enzyme_ad/jax/deps/tensat/converted.txt

+6
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,9 @@
2626
(ConcatenateOp (Vec (MulOp ?x ?y) (MulOp ?z ?w)) ?i)<=>(MulOp (ConcatenateOp (Vec ?x ?z) ?i) (ConcatenateOp (Vec ?y ?w) ?i))
2727

2828
(ConcatenateOp (Vec (ConcatenateOp (Vec ?x ?y) 1) (ConcatenateOp (Vec ?z ?w) 1)) 0)<=>(ConcatenateOp (Vec (ConcatenateOp (Vec ?x ?z) 0) (ConcatenateOp (Vec ?y ?w) 0)) 1)
29+
30+
(ConvolutionOp (MulOp ?x ?w) ?y ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(ConvolutionOp ?x (MulOp ?y ?w) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)
31+
32+
(ConvolutionOp ?lhs (ConcatenateOp (Vec ?x ?y) ?i) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(ConcatenateOp (Vec (ConvolutionOp ?lhs ?x ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig) (ConvolutionOp ?lhs ?y ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)) ?i)
33+
34+
(ConvolutionOp ?lhs (MulOp ?rhs ?w) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(MulOp (ConvolutionOp ?lhs ?rhs ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig) ?w)
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
use crate::{
22
input::ffi,
33
model::*,
4-
rewrites::{get_num_option, get_vec_of_nums_option, get_vec_option},
4+
rewrites::{get_matrix_option, get_num_option, get_vec_of_nums_option, get_vec_option},
55
};
66
use egg::*;
77

88
fn process_enode_args(
99
egraph: &EGraph<Mdl, TensorAnalysis>,
1010
enode: &Mdl,
11-
) -> (Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>) {
11+
) -> (
12+
Vec<ffi::Tensor>,
13+
Vec<ffi::Vector>,
14+
Vec<i64>,
15+
Vec<ffi::Matrix>,
16+
) {
1217
let mut args: Vec<ffi::Tensor> = vec![];
1318
let mut other_vecs: Vec<ffi::Vector> = vec![];
1419
let mut int_args: Vec<i64> = vec![];
20+
let mut matrix_args: Vec<ffi::Matrix> = vec![];
1521

1622
for child in enode.children().iter() {
1723
if let Some(other_vec) = get_vec_of_nums_option(egraph, &egraph[*child]) {
1824
other_vecs.push(other_vec)
25+
} else if let Some(mat) = get_matrix_option(egraph, &egraph[*child]) {
26+
println!("{:?}", mat);
27+
matrix_args.push(mat)
1928
} else if let Some(vec) = get_vec_option(&egraph[*child]) {
2029
vec.iter()
2130
.for_each(|&id| args.push(egraph[id].data.tensors[0].clone()))
@@ -27,7 +36,7 @@ fn process_enode_args(
2736
}
2837
}
2938

30-
(args, other_vecs, int_args)
39+
(args, other_vecs, int_args, matrix_args)
3140
}
3241

3342
pub fn create_stablehlo_op<F, R>(
@@ -36,10 +45,10 @@ pub fn create_stablehlo_op<F, R>(
3645
process_output: F,
3746
) -> R
3847
where
39-
F: Fn(ffi::Ops, Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>) -> R,
48+
F: Fn(ffi::Ops, Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>, Vec<ffi::Matrix>) -> R,
4049
{
4150
let op = ffi::Ops::from_mdl(enode);
42-
let (args, other_vecs, int_args) = process_enode_args(egraph, enode);
43-
let res = process_output(op, args, other_vecs, int_args);
51+
let (args, other_vecs, int_args, matrix_args) = process_enode_args(egraph, enode);
52+
let res = process_output(op, args, other_vecs, int_args, matrix_args);
4453
res
4554
}

src/enzyme_ad/jax/deps/tensat/src/input.rs

+116-8
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub mod ffi {
3434
SelectOp,
3535
ConcatenateOp,
3636
DotGeneralOp,
37+
ConvolutionOp,
3738
PadOp,
3839
SliceOp,
3940
TransposeOp,
@@ -83,6 +84,12 @@ pub mod ffi {
8384
pub vec: Vec<i64>,
8485
}
8586

87+
// Similarly, we're creating a Matrix type for vecs of vecs (padding)
88+
#[derive(Debug)]
89+
struct Matrix {
90+
pub mat: Vec<Vector>,
91+
}
92+
8693
// take floats from c++ and wrap them into f32s below
8794
extern "Rust" {
8895
type Mdl;
@@ -159,6 +166,29 @@ pub mod ffi {
159166
dimension: i64,
160167
output: Tensor,
161168
) -> Box<TensorInfo>;
169+
fn new_convolution_op(
170+
self: &mut CppGraphConverter,
171+
lhs: &TensorInfo,
172+
rhs: &TensorInfo,
173+
windowStrides: Vec<i64>,
174+
padding: Vec<Vector>,
175+
lhsDilation: Vec<i64>,
176+
rhsDilation: Vec<i64>,
177+
windowReversal: Vec<bool>,
178+
inputBatchDimension: i64,
179+
inputFeatureDimension: i64,
180+
inputSpatialDimension: Vec<i64>,
181+
kernelInputFeatureDimension: i64,
182+
kernelOutputFeatureDimension: i64,
183+
kernelSpatialDimension: Vec<i64>,
184+
outputBatchDimension: i64,
185+
outputFeatureDimension: i64,
186+
outputSpatialDimension: Vec<i64>,
187+
featureGroupCount: i64,
188+
batchGroupCount: i64,
189+
precision_config: Vec<i64>,
190+
output: Tensor,
191+
) -> Box<TensorInfo>;
162192
fn new_dot_general_op(
163193
self: &mut CppGraphConverter,
164194
lhs: &TensorInfo,
@@ -274,7 +304,7 @@ pub mod ffi {
274304
fn new_blackbox_op(
275305
self: &mut CppGraphConverter,
276306
inpts: &[*mut TensorInfo],
277-
captured: &[*mut TensorInfo], // values that appear in a block that was declared outside
307+
captured: &[*mut TensorInfo], // values that appear in a block that was declared outside
278308
cpp_num: i64,
279309
outputs: &Vec<Tensor>,
280310
) -> Box<TensorInfo>;
@@ -293,6 +323,7 @@ pub mod ffi {
293323
operands: Vec<Tensor>,
294324
other_vector_args: Vec<Vector>,
295325
int_args: Vec<i64>,
326+
matrix_args: Vec<Matrix>,
296327
) -> u64;
297328
}
298329

@@ -304,6 +335,7 @@ pub mod ffi {
304335
operands: Vec<Tensor>,
305336
other_vector_args: Vec<Vector>,
306337
int_args: Vec<i64>,
338+
matrix_args: Vec<Matrix>,
307339
) -> Vec<Tensor>;
308340
}
309341
}
@@ -356,6 +388,7 @@ impl ffi::Ops {
356388
Mdl::PadOp(_) => Ops::PadOp,
357389
Mdl::SliceOp(_) => Ops::SliceOp,
358390
Mdl::TransposeOp(_) => Ops::TransposeOp,
391+
Mdl::ConvolutionOp(_) => Ops::ConvolutionOp,
359392
Mdl::MulOp(_) => Ops::MulOp,
360393
Mdl::AddOp(_) => Ops::AddOp,
361394
Mdl::DivOp(_) => Ops::DivOp,
@@ -601,10 +634,7 @@ impl CppGraphConverter {
601634
Box::new(res)
602635
}
603636

604-
fn new_tensorinfo_vec(
605-
&mut self,
606-
inputs: &[*mut TensorInfo]
607-
) -> Id {
637+
fn new_tensorinfo_vec(&mut self, inputs: &[*mut TensorInfo]) -> Id {
608638
let tensor_infos: Vec<&TensorInfo> = inputs.iter().map(|&ptr| unsafe { &*ptr }).collect();
609639
let inputs_node = Mdl::Vec(tensor_infos.iter().map(|i| i.id).collect());
610640
self.rec_expr.add(inputs_node)
@@ -630,6 +660,79 @@ impl CppGraphConverter {
630660
Box::new(res)
631661
}
632662

663+
pub fn new_convolution_op(
664+
&mut self,
665+
lhs: &TensorInfo,
666+
rhs: &TensorInfo,
667+
window_strides: Vec<i64>,
668+
padding: Vec<ffi::Vector>,
669+
lhs_dilation: Vec<i64>,
670+
rhs_dilation: Vec<i64>,
671+
window_reversal: Vec<bool>,
672+
input_batch_dimension: i64,
673+
input_feature_dimension: i64,
674+
input_spatial_dimensions: Vec<i64>,
675+
kernel_input_feature_dimension: i64,
676+
kernel_output_feature_dimension: i64,
677+
kernel_spatial_dimensions: Vec<i64>,
678+
output_batch_dimension: i64,
679+
output_feature_dimension: i64,
680+
output_spatial_dimensions: Vec<i64>,
681+
feature_group_count: i64,
682+
batch_group_count: i64,
683+
precision_config: Vec<i64>,
684+
output: ffi::Tensor,
685+
) -> Box<TensorInfo> {
686+
let window_strides_node_id = self.vec_node(window_strides);
687+
let lhs_dilation_node_id = self.vec_node(lhs_dilation);
688+
let rhs_dilation_node_id = self.vec_node(rhs_dilation);
689+
690+
// We could add a bool element type vec?
691+
let window_reversal_node_id =
692+
self.vec_node(window_reversal.iter().map(|x| *x as i64).collect());
693+
let input_spatial_dimensions_node_id = self.vec_node(input_spatial_dimensions);
694+
let kernel_spatial_dimensions_node_id = self.vec_node(kernel_spatial_dimensions);
695+
let output_spatial_dimensions_node_id = self.vec_node(output_spatial_dimensions);
696+
let precision_config_node_id = self.vec_node(precision_config);
697+
698+
let padding_node_ids: Vec<Id> = padding
699+
.into_iter()
700+
.map(|pad| self.vec_node(pad.vec))
701+
.collect::<Vec<Id>>();
702+
let padding_node_id = self.rec_expr.add(Mdl::Vec(padding_node_ids));
703+
704+
let new_node = Mdl::ConvolutionOp([
705+
lhs.id,
706+
rhs.id,
707+
window_strides_node_id,
708+
padding_node_id,
709+
lhs_dilation_node_id,
710+
rhs_dilation_node_id,
711+
window_reversal_node_id,
712+
self.add_or_get_val(input_batch_dimension),
713+
self.add_or_get_val(input_feature_dimension),
714+
input_spatial_dimensions_node_id,
715+
self.add_or_get_val(kernel_input_feature_dimension),
716+
self.add_or_get_val(kernel_output_feature_dimension),
717+
kernel_spatial_dimensions_node_id,
718+
self.add_or_get_val(output_batch_dimension),
719+
self.add_or_get_val(output_feature_dimension),
720+
output_spatial_dimensions_node_id,
721+
self.add_or_get_val(feature_group_count),
722+
self.add_or_get_val(batch_group_count),
723+
precision_config_node_id,
724+
]);
725+
726+
let res = TensorInfo {
727+
id: self.rec_expr.add(new_node),
728+
tensor_data: TensorData {
729+
tensors: vec![output],
730+
name: None,
731+
},
732+
};
733+
Box::new(res)
734+
}
735+
633736
pub fn new_dot_general_op(
634737
&mut self,
635738
lhs: &TensorInfo,
@@ -1043,6 +1146,7 @@ impl CppGraphConverter {
10431146
Mdl::DotGeneralOp(ops) => new_node(ops),
10441147
Mdl::SliceOp(ops) => new_node(ops),
10451148
Mdl::TransposeOp(ops) => new_node(ops),
1149+
Mdl::ConvolutionOp(ops) => new_node(ops),
10461150
Mdl::MulOp(ops) => new_node(ops),
10471151
Mdl::AddOp(ops) => new_node(ops),
10481152
Mdl::DivOp(ops) => new_node(ops),
@@ -1059,7 +1163,7 @@ impl CppGraphConverter {
10591163
Mdl::SSplit0(ops) => new_node(ops),
10601164
Mdl::SSplit1(ops) => new_node(ops),
10611165
Mdl::MatchRank(ops) => new_node(ops),
1062-
_ => unimplemented!()
1166+
_ => unimplemented!(),
10631167
};
10641168

10651169
res.push(node);
@@ -1088,7 +1192,8 @@ impl CppGraphConverter {
10881192
read_to_string(rule_file).expect("Something went wrong reading the rule file");
10891193
let time_limit_sec = Duration::new(n_sec, 0);
10901194
let pre_defined_rules = PRE_DEFINED_RULES.iter().map(|&x| x);
1091-
let split_rules: Vec<&str> = learned_rules.split("\n")
1195+
let split_rules: Vec<&str> = learned_rules
1196+
.split("\n")
10921197
.filter(|x| !x.is_empty())
10931198
.chain(pre_defined_rules)
10941199
.collect();
@@ -1234,7 +1339,10 @@ fn extract_by_ilp(
12341339
let class_constraint = true;
12351340
let no_order = true;
12361341
let initialise_with_greedy = false;
1237-
let fusion_costs: bool = std::env::var("FUSION_COSTS").unwrap_or(String::from("true")).parse().unwrap();
1342+
let fusion_costs: bool = std::env::var("FUSION_COSTS")
1343+
.unwrap_or(String::from("true"))
1344+
.parse()
1345+
.unwrap();
12381346
let mut arg_vec = vec!["src/enzyme_ad/jax/deps/tensat/extractor/extract.py"];
12391347
if order_var_int {
12401348
arg_vec.push("--order_var_int");

src/enzyme_ad/jax/deps/tensat/src/model.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ define_language! {
2626
"GatherOp" = GatherOp([Id; 10]),
2727
"SelectOp" = SelectOp([Id; 3]), // pred, on_true, on_false
2828
"ConcatenateOp" = ConcatenateOp([Id; 2]), // inputs, dimension
29+
"ConvolutionOp" = ConvolutionOp([Id; 19]), // LOTS of inputs
2930
"DotGeneralOp" = DotGeneralOp([Id; 7]), // lhs, rhs, ..., shape
3031
"PadOp" = PadOp([Id; 5]), // input, padding_value, edge_padding_low,
3132
// edge_padding_high, interior_padding
@@ -50,7 +51,7 @@ define_language! {
5051
// Complete pain, has arity 12
5152
"ScatterOp" = ScatterOp([Id; 4]), // input, scatter_indices, updates, dimension_numbers
5253
"ReturnOp" = ReturnOp([Id; 1]),
53-
"BlackBox" = BlackBox([Id; 3]), // id, args, captured values (last two should be vecs)
54+
"BlackBox" = BlackBox([Id; 3]), // id, args, captured values (last two should be vecs)
5455
"Vec" = Vec(Vec<Id>),
5556
"Index" = Index([Id; 2]), // index, input. for indexing into ops with multiple result Values.
5657
// SHORTHANDS (not 1:1 with stablehlo)

0 commit comments

Comments
 (0)