
Commit 6013968

Merge pull request #221 from influxdata/crepererum/sandbox-dt-edition
feat: check DataType-related conversion
2 parents b7fd8e2 + 3e1e12c commit 6013968

File tree: 8 files changed (+561 −17 lines)
guests/evil/src/complex/many_inputs.rs

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
//! UDF with many inputs.

use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};

/// UDF with specific signature.
#[derive(Debug, PartialEq, Eq, Hash)]
struct SignatureUDF {
    /// Name.
    name: &'static str,

    /// The signature.
    signature: Signature,
}

impl SignatureUDF {
    /// Create new UDF.
    fn new(name: &'static str, signature: Signature) -> Self {
        Self { name, signature }
    }
}

impl ScalarUDFImpl for SignatureUDF {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
        Ok(DataType::Null)
    }

    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        Err(DataFusionError::NotImplemented(
            "invoke_with_args".to_owned(),
        ))
    }
}

/// Returns our evil UDFs.
///
/// The passed `source` is ignored.
#[expect(clippy::unnecessary_wraps, reason = "public API through export! macro")]
pub(crate) fn udfs(_source: String) -> DataFusionResult<Vec<Arc<dyn ScalarUDFImpl>>> {
    let limit: usize = std::env::var("limit").unwrap().parse().unwrap();

    Ok(vec![Arc::new(SignatureUDF::new(
        "input_count",
        Signature::exact(vec![DataType::Null; limit + 1], Volatility::Immutable),
    ))])
}
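This guest deliberately advertises a signature with one more argument than the configured limit (`limit + 1` inputs). A minimal host-side sketch of the kind of arity check this payload exercises; `check_signature_arity` and `max_inputs` are hypothetical names and not part of this crate, which enforces its limits elsewhere during conversion:

use arrow::datatypes::DataType;
use datafusion_common::DataFusionError;
use datafusion_expr::{Signature, TypeSignature, Volatility};

/// Reject signatures that enumerate more input types than `max_inputs`.
/// (Illustrative only; the real host applies its own limit checks.)
fn check_signature_arity(sig: &Signature, max_inputs: usize) -> Result<(), DataFusionError> {
    if let TypeSignature::Exact(types) = &sig.type_signature {
        if types.len() > max_inputs {
            return Err(DataFusionError::External(
                format!("too many UDF inputs: {} > {max_inputs}", types.len()).into(),
            ));
        }
    }
    Ok(())
}

fn main() {
    // Mirror the `input_count` payload with a fixed limit of 10.
    let sig = Signature::exact(vec![DataType::Null; 11], Volatility::Immutable);
    assert!(check_signature_arity(&sig, 10).is_err());
}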

guests/evil/src/complex/mod.rs

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ use datafusion_expr::{
 };

 pub(crate) mod error;
+pub(crate) mod many_inputs;
+pub(crate) mod return_type;
+pub(crate) mod return_value;
 pub(crate) mod udf_long_name;
 pub(crate) mod udfs_duplicate_names;
 pub(crate) mod udfs_many;
guests/evil/src/complex/return_type.rs

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
//! Badness related to return data types.

use std::{collections::HashMap, sync::Arc};

use arrow::datatypes::{DataType, Field};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};

/// UDF that returns a specific data type.
#[derive(Debug, PartialEq, Eq, Hash)]
struct ReturnTypeUDF {
    /// Name.
    name: &'static str,

    /// The return type
    return_type: DataType,
}

impl ReturnTypeUDF {
    /// Create new UDF.
    fn new(name: &'static str, return_type: DataType) -> Self {
        Self { name, return_type }
    }
}

impl ScalarUDFImpl for ReturnTypeUDF {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        self.name
    }

    fn signature(&self) -> &Signature {
        static S: Signature = Signature {
            type_signature: TypeSignature::Uniform(0, vec![]),
            volatility: Volatility::Immutable,
        };

        &S
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
        Ok(self.return_type.clone())
    }

    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        Err(DataFusionError::NotImplemented(
            "invoke_with_args".to_owned(),
        ))
    }
}

/// Returns our evil UDFs.
///
/// The passed `source` is ignored.
#[expect(clippy::unnecessary_wraps, reason = "public API through export! macro")]
pub(crate) fn udfs(_source: String) -> DataFusionResult<Vec<Arc<dyn ScalarUDFImpl>>> {
    let max_identifier_length: usize = std::env::var("max_identifier_length")
        .unwrap()
        .parse()
        .unwrap();
    let max_aux_string_length: usize = std::env::var("max_aux_string_length")
        .unwrap()
        .parse()
        .unwrap();
    let max_depth: usize = std::env::var("max_depth").unwrap().parse().unwrap();
    let max_complexity: usize = std::env::var("max_complexity").unwrap().parse().unwrap();

    Ok(vec![
        Arc::new(ReturnTypeUDF::new(
            "dt_depth",
            (0..=max_depth).fold(DataType::Int64, |dt, _| {
                DataType::List(Arc::new(Field::new("f", dt, true)))
            }),
        )),
        Arc::new(ReturnTypeUDF::new(
            "field_name",
            DataType::List(Arc::new(Field::new(
                std::iter::repeat_n('x', max_identifier_length + 1).collect::<String>(),
                DataType::Int64,
                true,
            ))),
        )),
        Arc::new(ReturnTypeUDF::new(
            "field_md_key",
            DataType::List(Arc::new(
                Field::new("f", DataType::Int64, true).with_metadata(HashMap::from([(
                    std::iter::repeat_n('x', max_identifier_length + 1).collect::<String>(),
                    "value".to_owned(),
                )])),
            )),
        )),
        Arc::new(ReturnTypeUDF::new(
            "field_md_value",
            DataType::List(Arc::new(
                Field::new("f", DataType::Int64, true).with_metadata(HashMap::from([(
                    "key".to_owned(),
                    std::iter::repeat_n('x', max_aux_string_length + 1).collect::<String>(),
                )])),
            )),
        )),
        Arc::new(ReturnTypeUDF::new(
            "field_md_items",
            DataType::List(Arc::new(
                Field::new("f", DataType::Int64, true).with_metadata(
                    (0..=max_complexity)
                        .map(|idx| (idx.to_string(), "value".to_owned()))
                        .collect(),
                ),
            )),
        )),
    ])
}
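The `dt_depth` payload folds `max_depth + 1` `List` layers around an `Int64`, so any bounded-depth check on the host must reject it. A standalone sketch of measuring that nesting; `list_depth` is an illustrative helper, not the crate's API (the host instead spends a `ComplexityToken::sub` unit per level):

use std::sync::Arc;

use arrow::datatypes::{DataType, Field};

/// Count how many nested `List` levels a data type carries.
/// (Illustrative sketch; the host tracks nesting via its complexity token.)
fn list_depth(dt: &DataType) -> usize {
    match dt {
        DataType::List(field) => 1 + list_depth(field.data_type()),
        _ => 0,
    }
}

fn main() {
    // Mirror the `dt_depth` fold above with a fixed depth of 3.
    let dt = (0..3).fold(DataType::Int64, |dt, _| {
        DataType::List(Arc::new(Field::new("f", dt, true)))
    });
    assert_eq!(list_depth(&dt), 3);
}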
guests/evil/src/complex/return_value.rs

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
//! Payloads that return complex values when being invoked.

use std::sync::Arc;

use arrow::{
    array::{Int64Array, StructArray},
    datatypes::{DataType, Field},
};
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};

use crate::common::DynBox;

/// UDF that returns a specific data type.
#[derive(Debug, PartialEq, Eq, Hash)]
struct ReturnValueUDF {
    /// Name.
    name: &'static str,

    /// The return value
    return_value: DynBox<ColumnarValue>,
}

impl ReturnValueUDF {
    /// Create new UDF.
    fn new(name: &'static str, return_value: ColumnarValue) -> Self {
        Self {
            name,
            return_value: DynBox(Box::new(return_value)),
        }
    }
}

impl ScalarUDFImpl for ReturnValueUDF {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        self.name
    }

    fn signature(&self) -> &Signature {
        static S: Signature = Signature {
            type_signature: TypeSignature::Uniform(0, vec![]),
            volatility: Volatility::Immutable,
        };

        &S
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
        Err(DataFusionError::NotImplemented("return_type".to_owned()))
    }

    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        Ok(self.return_value.clone())
    }
}

/// Returns our evil UDFs.
///
/// The passed `source` is ignored.
#[expect(clippy::unnecessary_wraps, reason = "public API through export! macro")]
pub(crate) fn udfs(_source: String) -> DataFusionResult<Vec<Arc<dyn ScalarUDFImpl>>> {
    let max_depth: usize = std::env::var("max_depth").unwrap().parse().unwrap();

    Ok(vec![
        Arc::new(ReturnValueUDF::new(
            "dt_depth_array",
            ColumnarValue::Array((0..=max_depth).fold(
                Arc::new(Int64Array::new(vec![1].into(), None)),
                |a, _| {
                    Arc::new(StructArray::from(vec![(
                        Arc::new(Field::new("a", a.data_type().clone(), true)),
                        a,
                    )]))
                },
            )),
        )),
        Arc::new(ReturnValueUDF::new(
            "dt_depth_scalar",
            ColumnarValue::Scalar((0..=max_depth).fold(ScalarValue::Int64(Some(1)), |v, _| {
                ScalarValue::Struct(Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("a", v.data_type(), true)),
                    v.to_array().unwrap(),
                )])))
            })),
        )),
    ])
}
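Both return values wrap an `Int64` in `max_depth + 1` struct layers, so the host has to reject the value (or its data type) when converting it back from the guest. A self-contained sketch of the same fold plus an illustrative depth count; the counting loop is hypothetical and stands in for the host's `ComplexityToken`-based checks:

use std::sync::Arc;

use arrow::{
    array::{Array, ArrayRef, Int64Array, StructArray},
    datatypes::{DataType, Field},
};

fn main() {
    // Wrap an Int64 array in three struct layers, mirroring the fold above.
    let nested: ArrayRef = (0..3).fold(
        Arc::new(Int64Array::new(vec![1].into(), None)) as ArrayRef,
        |a, _| {
            Arc::new(StructArray::from(vec![(
                Arc::new(Field::new("a", a.data_type().clone(), true)),
                a,
            )])) as ArrayRef
        },
    );

    // Count struct nesting levels (illustrative; not the crate's actual check).
    let mut depth = 0;
    let mut dt = nested.data_type().clone();
    while let DataType::Struct(fields) = dt {
        depth += 1;
        dt = fields[0].data_type().clone();
    }
    assert_eq!(depth, 3);
}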

guests/evil/src/lib.rs

Lines changed: 12 additions & 0 deletions
@@ -40,6 +40,18 @@ impl Evil {
                 root: Box::new(common::root_empty),
                 udfs: Box::new(complex::error::udfs),
             },
+            "complex::many_inputs" => Self {
+                root: Box::new(common::root_empty),
+                udfs: Box::new(complex::many_inputs::udfs),
+            },
+            "complex::return_type" => Self {
+                root: Box::new(common::root_empty),
+                udfs: Box::new(complex::return_type::udfs),
+            },
+            "complex::return_value" => Self {
+                root: Box::new(common::root_empty),
+                udfs: Box::new(complex::return_value::udfs),
+            },
             "complex::udf_long_name" => Self {
                 root: Box::new(common::root_empty),
                 udfs: Box::new(complex::udf_long_name::udfs),

host/src/conversion/mod.rs

Lines changed: 29 additions & 15 deletions
@@ -1,4 +1,6 @@
 //! Conversion routes from/to WIT types.
+use std::collections::HashMap;
+
 use arrow::{
     array::ArrayRef,
     datatypes::{DataType, Field, IntervalUnit, TimeUnit, UnionFields, UnionMode},
@@ -94,7 +96,7 @@ fn check_union_fields(
     let token = token.sub()?;

     for (_idx, field) in ufields.iter() {
-        check_field(field, &token)?;
+        check_field(field, &token).context("field")?;
     }
     Ok(())
 }
@@ -167,26 +169,26 @@ fn check_data_type(
         | DataType::FixedSizeList(field, _)
         | DataType::LargeList(field)
         | DataType::LargeListView(field)
-        | DataType::Map(field, _) => check_field(field, &token),
+        | DataType::Map(field, _) => check_field(field, &token).context("field"),
         DataType::Struct(fields) => {
-            for field in fields {
-                check_field(field, &token)?;
+            for (idx, field) in fields.iter().enumerate() {
+                check_field(field, &token).with_context(|| format!("field {idx}"))?;
             }
             Ok(())
         }
         DataType::Union(ufields, umode) => {
-            check_union_fields(ufields, &token)?;
-            check_union_mode(umode, &token)?;
+            check_union_fields(ufields, &token).context("union fields")?;
+            check_union_mode(umode, &token).context("union mode")?;
             Ok(())
         }
         DataType::Dictionary(dt1, dt2) => {
-            check_data_type(dt1, &token)?;
-            check_data_type(dt2, &token)?;
+            check_data_type(dt1, &token).context("key type")?;
+            check_data_type(dt2, &token).context("value type")?;
             Ok(())
         }
         DataType::RunEndEncoded(field1, field2) => {
-            check_field(field1, &token)?;
-            check_field(field2, &token)?;
+            check_field(field1, &token).context("REE run-ends field")?;
+            check_field(field2, &token).context("REE value field")?;
             Ok(())
         }
     }
@@ -196,12 +198,24 @@ fn check_data_type(
 fn check_field(field: &Field, token: &limits::ComplexityToken) -> datafusion_common::Result<()> {
     let token = token.sub()?;

-    token.check_identifier(field.name())?;
-    check_data_type(field.data_type(), &token)?;
+    token.check_identifier(field.name()).context("field name")?;
+    check_data_type(field.data_type(), &token).context("field data type")?;
+    check_metadata(field.metadata(), &token).context("field metadata")?;
+
+    Ok(())
+}
+
+/// Check metadata complexity.
+fn check_metadata(
+    md: &HashMap<String, String>,
+    token: &limits::ComplexityToken,
+) -> datafusion_common::Result<()> {
+    let token = token.sub()?;

-    for (k, v) in field.metadata() {
-        token.check_identifier(k)?;
-        token.check_aux_string(v)?;
+    for (k, v) in md {
+        let token = token.sub()?;
+        token.check_identifier(k).context("metadata key")?;
+        token.check_aux_string(v).context("metadata value")?;
     }

     Ok(())
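With this change, `check_field` also charges field metadata against the complexity budget: the new `check_metadata` draws one `sub()` token for the map and one more per key/value pair, so oversized keys, values, or entry counts now fail with a contextual error instead of slipping through. A simplified sketch of that budgeting pattern, using a hypothetical `Budget` type in place of the crate's `ComplexityToken`:

use std::{cell::Cell, collections::HashMap};

/// Hypothetical stand-in for the crate's `ComplexityToken`: a shared budget
/// that every nested element must draw from.
struct Budget(Cell<usize>);

impl Budget {
    fn sub(&self) -> Result<&Self, String> {
        let left = self.0.get();
        if left == 0 {
            return Err("complexity limit exceeded".to_owned());
        }
        self.0.set(left - 1);
        Ok(self)
    }
}

fn check_metadata(md: &HashMap<String, String>, budget: &Budget) -> Result<(), String> {
    let budget = budget.sub()?; // one unit for the map itself
    for _entry in md {
        budget.sub()?; // one unit per key/value pair
    }
    Ok(())
}

fn main() {
    let md: HashMap<_, _> = (0..10).map(|i| (i.to_string(), "v".to_owned())).collect();
    assert!(check_metadata(&md, &Budget(Cell::new(5))).is_err());
    assert!(check_metadata(&md, &Budget(Cell::new(100))).is_ok());
}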

host/src/lib.rs

Lines changed: 2 additions & 1 deletion
@@ -620,7 +620,8 @@ impl WasmScalarUdf {
                 Some(&store_guard.data().stderr.contents()),
             )?;
             ComplexityToken::new(permissions.trusted_data_limits.clone())?
-                .check_identifier(&name)?;
+                .check_identifier(&name)
+                .context("UDF name")?;
             if !names_seen.insert(name.clone()) {
                 return Err(DataFusionError::External(
                     format!("non-unique UDF name: '{name}'").into(),
