Skip to content

Commit 777a318

Browse files
authored
Refactor substrait producer into multiple files (#16089)
1 parent e3e7d50 commit 777a318

26 files changed

+3683
-3000
lines changed

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 0 additions & 3000 deletions
This file was deleted.
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::logical_plan::producer::SubstraitProducer;
19+
use datafusion::common::DFSchemaRef;
20+
use datafusion::logical_expr::expr;
21+
use datafusion::logical_expr::expr::AggregateFunctionParams;
22+
use substrait::proto::aggregate_function::AggregationInvocation;
23+
use substrait::proto::aggregate_rel::Measure;
24+
use substrait::proto::function_argument::ArgType;
25+
use substrait::proto::sort_field::{SortDirection, SortKind};
26+
use substrait::proto::{
27+
AggregateFunction, AggregationPhase, FunctionArgument, SortField,
28+
};
29+
30+
pub fn from_aggregate_function(
31+
producer: &mut impl SubstraitProducer,
32+
agg_fn: &expr::AggregateFunction,
33+
schema: &DFSchemaRef,
34+
) -> datafusion::common::Result<Measure> {
35+
let expr::AggregateFunction {
36+
func,
37+
params:
38+
AggregateFunctionParams {
39+
args,
40+
distinct,
41+
filter,
42+
order_by,
43+
null_treatment: _null_treatment,
44+
},
45+
} = agg_fn;
46+
let sorts = if let Some(order_by) = order_by {
47+
order_by
48+
.iter()
49+
.map(|expr| to_substrait_sort_field(producer, expr, schema))
50+
.collect::<datafusion::common::Result<Vec<_>>>()?
51+
} else {
52+
vec![]
53+
};
54+
let mut arguments: Vec<FunctionArgument> = vec![];
55+
for arg in args {
56+
arguments.push(FunctionArgument {
57+
arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)),
58+
});
59+
}
60+
let function_anchor = producer.register_function(func.name().to_string());
61+
#[allow(deprecated)]
62+
Ok(Measure {
63+
measure: Some(AggregateFunction {
64+
function_reference: function_anchor,
65+
arguments,
66+
sorts,
67+
output_type: None,
68+
invocation: match distinct {
69+
true => AggregationInvocation::Distinct as i32,
70+
false => AggregationInvocation::All as i32,
71+
},
72+
phase: AggregationPhase::Unspecified as i32,
73+
args: vec![],
74+
options: vec![],
75+
}),
76+
filter: match filter {
77+
Some(f) => Some(producer.handle_expr(f, schema)?),
78+
None => None,
79+
},
80+
})
81+
}
82+
83+
/// Converts sort expression to corresponding substrait `SortField`
84+
fn to_substrait_sort_field(
85+
producer: &mut impl SubstraitProducer,
86+
sort: &expr::Sort,
87+
schema: &DFSchemaRef,
88+
) -> datafusion::common::Result<SortField> {
89+
let sort_kind = match (sort.asc, sort.nulls_first) {
90+
(true, true) => SortDirection::AscNullsFirst,
91+
(true, false) => SortDirection::AscNullsLast,
92+
(false, true) => SortDirection::DescNullsFirst,
93+
(false, false) => SortDirection::DescNullsLast,
94+
};
95+
Ok(SortField {
96+
expr: Some(producer.handle_expr(&sort.expr, schema)?),
97+
sort_kind: Some(SortKind::Direction(sort_kind.into())),
98+
})
99+
}
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer};
19+
use crate::variation_const::DEFAULT_TYPE_VARIATION_REF;
20+
use datafusion::common::{DFSchemaRef, ScalarValue};
21+
use datafusion::logical_expr::{Cast, Expr, TryCast};
22+
use substrait::proto::expression::cast::FailureBehavior;
23+
use substrait::proto::expression::literal::LiteralType;
24+
use substrait::proto::expression::{Literal, RexType};
25+
use substrait::proto::Expression;
26+
27+
pub fn from_cast(
28+
producer: &mut impl SubstraitProducer,
29+
cast: &Cast,
30+
schema: &DFSchemaRef,
31+
) -> datafusion::common::Result<Expression> {
32+
let Cast { expr, data_type } = cast;
33+
// since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null
34+
if let Expr::Literal(lit) = expr.as_ref() {
35+
// only the untyped(a null scalar value) null literal need this special handling
36+
// since all other kind of nulls are already typed and can be handled by substrait
37+
// e.g. null::<Int32Type> or null::<Utf8Type>
38+
if matches!(lit, ScalarValue::Null) {
39+
let lit = Literal {
40+
nullable: true,
41+
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
42+
literal_type: Some(LiteralType::Null(to_substrait_type(
43+
data_type, true,
44+
)?)),
45+
};
46+
return Ok(Expression {
47+
rex_type: Some(RexType::Literal(lit)),
48+
});
49+
}
50+
}
51+
Ok(Expression {
52+
rex_type: Some(RexType::Cast(Box::new(
53+
substrait::proto::expression::Cast {
54+
r#type: Some(to_substrait_type(data_type, true)?),
55+
input: Some(Box::new(producer.handle_expr(expr, schema)?)),
56+
failure_behavior: FailureBehavior::ThrowException.into(),
57+
},
58+
))),
59+
})
60+
}
61+
62+
pub fn from_try_cast(
63+
producer: &mut impl SubstraitProducer,
64+
cast: &TryCast,
65+
schema: &DFSchemaRef,
66+
) -> datafusion::common::Result<Expression> {
67+
let TryCast { expr, data_type } = cast;
68+
Ok(Expression {
69+
rex_type: Some(RexType::Cast(Box::new(
70+
substrait::proto::expression::Cast {
71+
r#type: Some(to_substrait_type(data_type, true)?),
72+
input: Some(Box::new(producer.handle_expr(expr, schema)?)),
73+
failure_behavior: FailureBehavior::ReturnNull.into(),
74+
},
75+
))),
76+
})
77+
}
78+
79+
#[cfg(test)]
80+
mod tests {
81+
use super::*;
82+
use crate::logical_plan::producer::to_substrait_extended_expr;
83+
use datafusion::arrow::datatypes::{DataType, Field};
84+
use datafusion::common::DFSchema;
85+
use datafusion::execution::SessionStateBuilder;
86+
use datafusion::logical_expr::ExprSchemable;
87+
use substrait::proto::expression_reference::ExprType;
88+
89+
#[tokio::test]
90+
async fn fold_cast_null() {
91+
let state = SessionStateBuilder::default().build();
92+
let empty_schema = DFSchemaRef::new(DFSchema::empty());
93+
let field = Field::new("out", DataType::Int32, false);
94+
95+
let expr = Expr::Literal(ScalarValue::Null)
96+
.cast_to(&DataType::Int32, &empty_schema)
97+
.unwrap();
98+
99+
let typed_null =
100+
to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)
101+
.unwrap();
102+
103+
if let ExprType::Expression(expr) =
104+
typed_null.referred_expr[0].expr_type.as_ref().unwrap()
105+
{
106+
let lit = Literal {
107+
nullable: true,
108+
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
109+
literal_type: Some(LiteralType::Null(
110+
to_substrait_type(&DataType::Int32, true).unwrap(),
111+
)),
112+
};
113+
let expected = Expression {
114+
rex_type: Some(RexType::Literal(lit)),
115+
};
116+
assert_eq!(*expr, expected);
117+
} else {
118+
panic!("Expected expression type");
119+
}
120+
121+
// a typed null should not be folded
122+
let expr = Expr::Literal(ScalarValue::Int64(None))
123+
.cast_to(&DataType::Int32, &empty_schema)
124+
.unwrap();
125+
126+
let typed_null =
127+
to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)
128+
.unwrap();
129+
130+
if let ExprType::Expression(expr) =
131+
typed_null.referred_expr[0].expr_type.as_ref().unwrap()
132+
{
133+
let cast_expr = substrait::proto::expression::Cast {
134+
r#type: Some(to_substrait_type(&DataType::Int32, true).unwrap()),
135+
input: Some(Box::new(Expression {
136+
rex_type: Some(RexType::Literal(Literal {
137+
nullable: true,
138+
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
139+
literal_type: Some(LiteralType::Null(
140+
to_substrait_type(&DataType::Int64, true).unwrap(),
141+
)),
142+
})),
143+
})),
144+
failure_behavior: FailureBehavior::ThrowException as i32,
145+
};
146+
let expected = Expression {
147+
rex_type: Some(RexType::Cast(Box::new(cast_expr))),
148+
};
149+
assert_eq!(*expr, expected);
150+
} else {
151+
panic!("Expected expression type");
152+
}
153+
}
154+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::common::{substrait_err, Column, DFSchemaRef};
19+
use datafusion::logical_expr::Expr;
20+
use substrait::proto::expression::field_reference::{
21+
ReferenceType, RootReference, RootType,
22+
};
23+
use substrait::proto::expression::{
24+
reference_segment, FieldReference, ReferenceSegment, RexType,
25+
};
26+
use substrait::proto::Expression;
27+
28+
pub fn from_column(
29+
col: &Column,
30+
schema: &DFSchemaRef,
31+
) -> datafusion::common::Result<Expression> {
32+
let index = schema.index_of_column(col)?;
33+
substrait_field_ref(index)
34+
}
35+
36+
pub(crate) fn substrait_field_ref(
37+
index: usize,
38+
) -> datafusion::common::Result<Expression> {
39+
Ok(Expression {
40+
rex_type: Some(RexType::Selection(Box::new(FieldReference {
41+
reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
42+
reference_type: Some(reference_segment::ReferenceType::StructField(
43+
Box::new(reference_segment::StructField {
44+
field: index as i32,
45+
child: None,
46+
}),
47+
)),
48+
})),
49+
root_type: Some(RootType::RootReference(RootReference {})),
50+
}))),
51+
})
52+
}
53+
54+
/// Try to convert an [Expr] to a [FieldReference].
55+
/// Returns `Err` if the [Expr] is not a [Expr::Column].
56+
pub(crate) fn try_to_substrait_field_reference(
57+
expr: &Expr,
58+
schema: &DFSchemaRef,
59+
) -> datafusion::common::Result<FieldReference> {
60+
match expr {
61+
Expr::Column(col) => {
62+
let index = schema.index_of_column(col)?;
63+
Ok(FieldReference {
64+
reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
65+
reference_type: Some(reference_segment::ReferenceType::StructField(
66+
Box::new(reference_segment::StructField {
67+
field: index as i32,
68+
child: None,
69+
}),
70+
)),
71+
})),
72+
root_type: Some(RootType::RootReference(RootReference {})),
73+
})
74+
}
75+
_ => substrait_err!("Expect a `Column` expr, but found {expr:?}"),
76+
}
77+
}
78+
79+
#[cfg(test)]
80+
mod tests {
81+
use super::*;
82+
use datafusion::common::Result;
83+
84+
#[test]
85+
fn to_field_reference() -> Result<()> {
86+
let expression = substrait_field_ref(2)?;
87+
88+
match &expression.rex_type {
89+
Some(RexType::Selection(field_ref)) => {
90+
assert_eq!(
91+
field_ref
92+
.root_type
93+
.clone()
94+
.expect("root type should be set"),
95+
RootType::RootReference(RootReference {})
96+
);
97+
}
98+
99+
_ => panic!("Should not be anything other than field reference"),
100+
}
101+
Ok(())
102+
}
103+
}

0 commit comments

Comments
 (0)