diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index a8ac997ce33ec..6d66165760e2e 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_except function. +//! [`ScalarUDFImpl`] definition for array_except function. use crate::utils::{check_datatypes, make_scalar_function}; +use arrow::array::new_null_array; use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, cast::AsArray}; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::datatypes::{DataType, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::utils::{ListCoercion, take_function_args}; @@ -28,6 +29,7 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use itertools::Itertools; use std::any::Any; use std::sync::Arc; @@ -104,8 +106,11 @@ impl ScalarUDFImpl for ArrayExcept { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (&arg_types[0].clone(), &arg_types[1].clone()) { - (DataType::Null, _) | (_, DataType::Null) => Ok(arg_types[0].clone()), + match (&arg_types[0], &arg_types[1]) { + (DataType::Null, DataType::Null) => { + Ok(DataType::new_list(DataType::Null, true)) + } + (DataType::Null, dt) | (dt, DataType::Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -129,8 +134,16 @@ impl ScalarUDFImpl for ArrayExcept { fn array_except_inner(args: &[ArrayRef]) -> Result { let [array1, array2] = take_function_args("array_except", args)?; + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), + (DataType::Null, DataType::Null) => Ok(new_null_array( + &DataType::new_list(DataType::Null, true), + len, + )), + (DataType::Null, dt @ DataType::List(_)) + | (DataType::Null, dt @ DataType::LargeList(_)) + | (dt @ DataType::List(_), DataType::Null) + | (dt @ DataType::LargeList(_), DataType::Null) => Ok(new_null_array(dt, len)), (DataType::List(field), DataType::List(_)) => { check_datatypes("array_except", &[array1, array2])?; let list1 = array1.as_list::(); @@ -169,14 +182,25 @@ fn general_except( let mut rows = Vec::with_capacity(l_values.num_rows()); let mut dedup = HashSet::new(); - for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { - let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); - let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); - for i in r_slice { + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + + for (i, ((l_start, l_end), (r_start, r_end))) in l + .offsets() + .iter() + .tuple_windows() + .zip(r.offsets().iter().tuple_windows()) + .enumerate() + { + if nulls.as_ref().is_some_and(|nulls| nulls.is_null(i)) { + offsets.push(OffsetSize::usize_as(rows.len())); + continue; + } + + for i in r_start.as_usize()..r_end.as_usize() { let right_row = r_values.row(i); dedup.insert(right_row); } - for i in l_slice { + for i in l_start.as_usize()..l_end.as_usize() { let left_row = l_values.row(i); if dedup.insert(left_row) { rows.push(left_row); @@ -192,7 +216,7 @@ fn general_except( field.to_owned(), OffsetBuffer::new(offsets.into()), values.to_owned(), - l.nulls().cloned(), + nulls, )) } else { internal_err!("array_except failed to convert rows") diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 69a220e125c04..8799af6d491c2 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -19,8 +19,7 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, GenericListArray, LargeListArray, ListArray, OffsetSizeTrait, - new_null_array, + Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_empty_array, new_null_array, }; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute; @@ -69,7 +68,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.", + description = "Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.", syntax_example = "array_union(array1, array2)", sql_example = r#"```sql > select array_union([1, 2, 3, 4], [5, 6, 3, 4]); @@ -136,8 +135,7 @@ impl ScalarUDFImpl for ArrayUnion { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, dt) | (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -221,8 +219,7 @@ impl ScalarUDFImpl for ArrayIntersect { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, dt) | (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -363,23 +360,19 @@ fn generic_set_lists( let mut offsets = vec![OffsetSize::usize_as(0)]; let mut new_arrays = vec![]; - let mut new_null_buf = vec![]; let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; - for (first_arr, second_arr) in l.iter().zip(r.iter()) { - let mut ele_should_be_null = false; + for (l_arr, r_arr) in l.iter().zip(r.iter()) { + let last_offset = *offsets.last().unwrap(); - let l_values = if let Some(first_arr) = first_arr { - converter.convert_columns(&[first_arr])? - } else { - ele_should_be_null = true; - converter.empty_rows(0, 0) - }; - - let r_values = if let Some(second_arr) = second_arr { - converter.convert_columns(&[second_arr])? - } else { - ele_should_be_null = true; - converter.empty_rows(0, 0) + let (l_values, r_values) = match (l_arr, r_arr) { + (Some(l_arr), Some(r_arr)) => ( + converter.convert_columns(&[l_arr])?, + converter.convert_columns(&[r_arr])?, + ), + _ => { + offsets.push(last_offset); + continue; + } }; let l_iter = l_values.iter().sorted().dedup(); @@ -405,11 +398,6 @@ fn generic_set_lists( } } - let last_offset = match offsets.last() { - Some(offset) => *offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; let array = match arrays.first() { @@ -419,18 +407,21 @@ fn generic_set_lists( } }; - new_null_buf.push(!ele_should_be_null); new_arrays.push(array); } let offsets = OffsetBuffer::new(offsets.into()); let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect(); - let values = compute::concat(&new_arrays_ref)?; + let values = if new_arrays_ref.is_empty() { + new_empty_array(&l.value_type()) + } else { + compute::concat(&new_arrays_ref)? + }; let arr = GenericListArray::::try_new( field, offsets, values, - Some(NullBuffer::new(new_null_buf.into())), + NullBuffer::union(l.nulls(), r.nulls()), )?; Ok(Arc::new(arr)) } @@ -440,59 +431,13 @@ fn general_set_op( array2: &ArrayRef, set_op: SetOp, ) -> Result { - fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result { - let field = Arc::new(Field::new_list_field(data_type.clone(), true)); - let values = new_null_array(data_type, len); - if large { - Ok(Arc::new(LargeListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } else { - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } - } - + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (Null, Null) => Ok(Arc::new(ListArray::new_null( - Arc::new(Field::new_list_field(Null, true)), - array1.len(), - ))), - (Null, List(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array2)?; - general_array_distinct::(array, field) - } - (List(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array1)?; - general_array_distinct::(array, field) - } - (Null, LargeList(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array2)?; - general_array_distinct::(array, field) - } - (LargeList(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array1)?; - general_array_distinct::(array, field) - } + (Null, Null) => Ok(new_null_array(&DataType::new_list(Null, true), len)), + (Null, dt @ List(_)) + | (Null, dt @ LargeList(_)) + | (dt @ List(_), Null) + | (dt @ LargeList(_), Null) => Ok(new_null_array(dt, len)), (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c31f3d0702358..e17322f0fe013 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -4770,12 +4770,12 @@ select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([[ query ? select array_union(null, []); ---- -[] +NULL query ? select array_union(null, arrow_cast([], 'LargeList(Int64)')); ---- -[] +NULL # array_union scalar function #10 query ? @@ -4787,23 +4787,23 @@ NULL query ? select array_union([1, 1, 2, 2, 3, 3], null); ---- -[1, 2, 3] +NULL query ? select array_union(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[1, 2, 3] +NULL # array_union scalar function #12 query ? select array_union(null, [1, 1, 2, 2, 3, 3]); ---- -[1, 2, 3] +NULL query ? select array_union(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -[1, 2, 3] +NULL # array_union scalar function #13 query ? @@ -4838,6 +4838,36 @@ NULL NULL NULL +query ? +select array_union(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_union([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + +query ? +select array_intersect(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_intersect([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + +query ? +select array_except(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_except([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + # list_to_string scalar function #4 (function alias `array_to_string`) query TTT select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); @@ -6888,27 +6918,27 @@ select array_intersect(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'Large query ? select array_intersect([1, 1, 2, 2, 3, 3], null); ---- -[] +NULL query ? select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[] +NULL query ? select array_intersect(null, [1, 1, 2, 2, 3, 3]); ---- -[] +NULL query ? select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -[] +NULL query ? select array_intersect([], null); ---- -[] +NULL query ? select array_intersect([[1,2,3]], [[]]); @@ -6923,17 +6953,17 @@ select array_intersect([[null]], [[]]); query ? select array_intersect(arrow_cast([], 'LargeList(Int64)'), null); ---- -[] +NULL query ? select array_intersect(null, []); ---- -[] +NULL query ? select array_intersect(null, arrow_cast([], 'LargeList(Int64)')); ---- -[] +NULL query ? select array_intersect(null, null); @@ -7476,7 +7506,7 @@ select array_except(column1, column2) from array_except_table; [2] [] NULL -[1, 2] +NULL NULL statement ok @@ -7497,7 +7527,7 @@ select array_except(column1, column2) from array_except_nested_list_table; ---- [[1, 2]] [[3]] -[[1, 2], [3]] +NULL NULL [] @@ -7536,7 +7566,7 @@ select array_except(column1, column2) from array_except_table_ut8; ---- [b, c] [a, bc] -[a, bc, def] +NULL NULL statement ok @@ -7558,7 +7588,7 @@ select array_except(column1, column2) from array_except_table_bool; [true] [true] [false] -[true, false] +NULL NULL statement ok @@ -7567,7 +7597,7 @@ drop table array_except_table_bool; query ? select array_except([], null); ---- -[] +NULL query ? select array_except([], []); diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index fe1ed1cab6bd7..125b2ac26f20d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -4211,7 +4211,7 @@ array_to_string(array, delimiter[, null_string]) ### `array_union` -Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. +Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates. ```sql array_union(array1, array2)