Skip to content

Commit 9c6db36

Browse files
committed
feat: convert_array_to_scalar_vec returns optional arrays
1 parent b81073a commit 9c6db36

File tree

6 files changed

+127
-64
lines changed

6 files changed

+127
-64
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,6 +3246,8 @@ impl ScalarValue {
32463246

32473247
/// Retrieve ScalarValue for each row in `array`
32483248
///
3249+
/// Elements in `array` itself may be NULL, in which case the corresponding element in the returned vector is None.
3250+
///
32493251
/// Example 1: Array (ScalarValue::Int32)
32503252
/// ```
32513253
/// use datafusion_common::ScalarValue;
@@ -3262,15 +3264,15 @@ impl ScalarValue {
32623264
/// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
32633265
///
32643266
/// let expected = vec![
3265-
/// vec![
3266-
/// ScalarValue::Int32(Some(1)),
3267-
/// ScalarValue::Int32(Some(2)),
3268-
/// ScalarValue::Int32(Some(3)),
3269-
/// ],
3270-
/// vec![
3271-
/// ScalarValue::Int32(Some(4)),
3272-
/// ScalarValue::Int32(Some(5)),
3273-
/// ],
3267+
/// Some(vec![
3268+
/// ScalarValue::Int32(Some(1)),
3269+
/// ScalarValue::Int32(Some(2)),
3270+
/// ScalarValue::Int32(Some(3)),
3271+
/// ]),
3272+
/// Some(vec![
3273+
/// ScalarValue::Int32(Some(4)),
3274+
/// ScalarValue::Int32(Some(5)),
3275+
/// ]),
32743276
/// ];
32753277
///
32763278
/// assert_eq!(scalar_vec, expected);
@@ -3303,28 +3305,60 @@ impl ScalarValue {
33033305
/// ]);
33043306
///
33053307
/// let expected = vec![
3306-
/// vec![
3308+
/// Some(vec![
33073309
/// ScalarValue::List(Arc::new(l1)),
33083310
/// ScalarValue::List(Arc::new(l2)),
3309-
/// ],
3311+
/// ]),
3312+
/// ];
3313+
///
3314+
/// assert_eq!(scalar_vec, expected);
3315+
/// ```
3316+
///
3317+
/// Example 3: Nullable array
3318+
/// ```
3319+
/// use datafusion_common::ScalarValue;
3320+
/// use arrow::array::ListArray;
3321+
/// use arrow::datatypes::{DataType, Int32Type};
3322+
///
3323+
/// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
3324+
/// Some(vec![Some(1), Some(2), Some(3)]),
3325+
/// None,
3326+
/// Some(vec![Some(4), Some(5)])
3327+
/// ]);
3328+
///
3329+
/// // Convert the array into Scalar Values for each row
3330+
/// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
3331+
///
3332+
/// let expected = vec![
3333+
/// Some(vec![
3334+
/// ScalarValue::Int32(Some(1)),
3335+
/// ScalarValue::Int32(Some(2)),
3336+
/// ScalarValue::Int32(Some(3)),
3337+
/// ]),
3338+
/// None,
3339+
/// Some(vec![
3340+
/// ScalarValue::Int32(Some(4)),
3341+
/// ScalarValue::Int32(Some(5)),
3342+
/// ]),
33103343
/// ];
33113344
///
33123345
/// assert_eq!(scalar_vec, expected);
33133346
/// ```
3314-
pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result<Vec<Vec<Self>>> {
3347+
pub fn convert_array_to_scalar_vec(
3348+
array: &dyn Array,
3349+
) -> Result<Vec<Option<Vec<Self>>>> {
33153350
fn generic_collect<OffsetSize: OffsetSizeTrait>(
33163351
array: &dyn Array,
3317-
) -> Result<Vec<Vec<ScalarValue>>> {
3352+
) -> Result<Vec<Option<Vec<ScalarValue>>>> {
33183353
array
33193354
.as_list::<OffsetSize>()
33203355
.iter()
33213356
.map(|nested_array| match nested_array {
33223357
Some(nested_array) => (0..nested_array.len())
33233358
.map(|i| ScalarValue::try_from_array(&nested_array, i))
3324-
.collect::<Result<Vec<_>>>(),
3325-
// TODO: what can we put for null?
3326-
// https://github.com/apache/datafusion/issues/17749
3327-
None => Ok(vec![]),
3359+
.collect::<Result<Vec<_>>>()
3360+
.map(Some),
3361+
None => Ok(None),
33283362
})
33293363
.collect()
33303364
}
@@ -9031,13 +9065,16 @@ mod tests {
90319065
assert_eq!(
90329066
converted,
90339067
vec![
9034-
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
9035-
vec![],
9036-
vec![
9068+
Some(vec![
9069+
ScalarValue::Int64(Some(1)),
9070+
ScalarValue::Int64(Some(2))
9071+
]),
9072+
None,
9073+
Some(vec![
90379074
ScalarValue::Int64(Some(3)),
90389075
ScalarValue::Int64(None),
90399076
ScalarValue::Int64(Some(4))
9040-
],
9077+
]),
90419078
]
90429079
);
90439080

@@ -9051,13 +9088,16 @@ mod tests {
90519088
assert_eq!(
90529089
converted,
90539090
vec![
9054-
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
9055-
vec![],
9056-
vec![
9091+
Some(vec![
9092+
ScalarValue::Int64(Some(1)),
9093+
ScalarValue::Int64(Some(2))
9094+
]),
9095+
None,
9096+
Some(vec![
90579097
ScalarValue::Int64(Some(3)),
90589098
ScalarValue::Int64(None),
90599099
ScalarValue::Int64(Some(4))
9060-
],
9100+
]),
90619101
]
90629102
);
90639103

@@ -9074,9 +9114,12 @@ mod tests {
90749114
assert_eq!(
90759115
converted,
90769116
vec![
9077-
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
9078-
vec![],
9079-
vec![ScalarValue::Int64(Some(5))],
9117+
Some(vec![
9118+
ScalarValue::Int64(Some(1)),
9119+
ScalarValue::Int64(Some(2))
9120+
]),
9121+
None,
9122+
Some(vec![ScalarValue::Int64(Some(5))]),
90809123
]
90819124
);
90829125
}

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
4848
let column = actual[0].column(0);
4949
assert_eq!(column.len(), 1);
5050
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?;
51-
let mut scalars = scalar_vec[0].clone();
51+
let mut scalars = scalar_vec[0].as_ref().unwrap().clone();
5252

5353
// workaround lack of Ord of ScalarValue
5454
let cmp = |a: &ScalarValue, b: &ScalarValue| {

datafusion/functions-aggregate-common/src/merge_arrays.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,25 @@ pub fn merge_ordered_arrays(
119119
// Defines according to which ordering comparisons should be done.
120120
sort_options: &[SortOptions],
121121
) -> datafusion_common::Result<(Vec<ScalarValue>, Vec<Vec<ScalarValue>>)> {
122-
// Keep track the most recent data of each branch, in binary heap data structure.
122+
// Keep track of the most recent data of each branch, in a binary heap data structure.
123123
let mut heap = BinaryHeap::<CustomElement>::new();
124124

125-
if values.len() != ordering_values.len()
126-
|| values
127-
.iter()
128-
.zip(ordering_values.iter())
129-
.any(|(vals, ordering_vals)| vals.len() != ordering_vals.len())
125+
if values.len() != ordering_values.len() {
126+
return exec_err!(
127+
"Expects values and ordering_values to have same size but got: values.len() = {}, ordering_values.len() = {}",
128+
values.len(),
129+
ordering_values.len()
130+
);
131+
}
132+
if values
133+
.iter()
134+
.zip(ordering_values.iter())
135+
.any(|(vals, ordering_vals)| vals.len() != ordering_vals.len())
130136
{
131137
return exec_err!(
132-
"Expects values arguments and/or ordering_values arguments to have same size"
138+
"Expects values elements and ordering_values elements to have same size but got: values.len() = {}, ordering_values.len() = {}",
139+
values.len(),
140+
ordering_values.len()
133141
);
134142
}
135143
let n_branch = values.len();

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -687,15 +687,19 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
687687

688688
// Convert array to Scalars to sort them easily. Convert back to array at evaluation.
689689
let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
690-
for v in array_agg_res.into_iter() {
691-
partition_values.push(v.into());
690+
for maybe_v in array_agg_res.into_iter() {
691+
if let Some(v) = maybe_v {
692+
partition_values.push(v.into());
693+
} else {
694+
partition_values.push(vec![].into());
695+
}
692696
}
693697

694698
let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
695-
696-
for partition_ordering_rows in orderings.into_iter() {
697-
// Extract value from struct to ordering_rows for each group/partition
698-
let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
699+
for maybe_partition_ordering_rows in orderings.into_iter() {
700+
if let Some(partition_ordering_rows) = maybe_partition_ordering_rows {
701+
// Extract value from struct to ordering_rows for each group/partition
702+
let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
699703
if let ScalarValue::Struct(s) = ordering_row {
700704
let mut ordering_columns_per_row = vec![];
701705

@@ -713,7 +717,8 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
713717
}
714718
}).collect::<Result<VecDeque<_>>>()?;
715719

716-
partition_ordering_values.push(ordering_value);
720+
partition_ordering_values.push(ordering_value);
721+
}
717722
}
718723

719724
let sort_options = self

datafusion/functions-aggregate/src/nth_value.rs

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,13 @@ impl Accumulator for TrivialNthValueAccumulator {
267267
// First entry in the state is the aggregation result.
268268
let n_required = self.n.unsigned_abs() as usize;
269269
let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
270-
for v in array_agg_res.into_iter() {
271-
self.values.extend(v);
272-
if self.values.len() > n_required {
273-
// There is enough data collected, can stop merging:
274-
break;
270+
for maybe_v in array_agg_res.into_iter() {
271+
if let Some(v) = maybe_v {
272+
self.values.extend(v);
273+
if self.values.len() > n_required {
274+
// There is enough data collected, can stop merging:
275+
break;
276+
}
275277
}
276278
}
277279
}
@@ -458,27 +460,31 @@ impl Accumulator for NthValueAccumulator {
458460
// First entry in the state is the aggregation result.
459461
let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
460462
for v in array_agg_res.into_iter() {
461-
partition_values.push(v.into());
463+
if let Some(v) = v {
464+
partition_values.push(v.into());
465+
}
462466
}
463467
// Stores ordering requirement expression results coming from each partition:
464468
let mut partition_ordering_values = vec![self.ordering_values.clone()];
465469
let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
466470
// Extract value from struct to ordering_rows for each group/partition:
467-
for partition_ordering_rows in orderings.into_iter() {
468-
let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| {
469-
let ScalarValue::Struct(s_array) = ordering_row else {
470-
return exec_err!(
471-
"Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
472-
ordering_row.data_type()
473-
);
474-
};
475-
s_array
476-
.columns()
477-
.iter()
478-
.map(|column| ScalarValue::try_from_array(column, 0))
479-
.collect()
480-
}).collect::<Result<VecDeque<_>>>()?;
481-
partition_ordering_values.push(ordering_values);
471+
for maybe_partition_ordering_rows in orderings.into_iter() {
472+
if let Some(partition_ordering_rows) = maybe_partition_ordering_rows {
473+
let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| {
474+
let ScalarValue::Struct(s_array) = ordering_row else {
475+
return exec_err!(
476+
"Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
477+
ordering_row.data_type()
478+
);
479+
};
480+
s_array
481+
.columns()
482+
.iter()
483+
.map(|column| ScalarValue::try_from_array(column, 0))
484+
.collect()
485+
}).collect::<Result<VecDeque<_>>>()?;
486+
partition_ordering_values.push(ordering_values);
487+
}
482488
}
483489

484490
let sort_options = self

datafusion/functions-nested/src/array_has.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ impl ScalarUDFImpl for ArrayHas {
145145
let list = scalar_values
146146
.into_iter()
147147
.flatten()
148+
.flatten()
148149
.map(|v| Expr::Literal(v, None))
149150
.collect();
150151

0 commit comments

Comments
 (0)