diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 87344914d2f7e..4860393cab2b7 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -2248,7 +2248,7 @@ impl TableOptions { /// Options that control how Parquet files are read, including global options /// that apply to all columns and optional column-specific overrides /// -/// Closely tied to [`ParquetWriterOptions`](crate::file_options::parquet_writer::ParquetWriterOptions). +/// Closely tied to `ParquetWriterOptions` (see `crate::file_options::parquet_writer::ParquetWriterOptions` when the "parquet" feature is enabled). /// Properties not included in [`TableParquetOptions`] may not be configurable at the external API /// (e.g. sorting_columns). #[derive(Clone, Default, Debug, PartialEq)] diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index 086d96e85230d..98eaf00a328cf 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -19,9 +19,9 @@ use crate::error::{_plan_err, Result}; use arrow::{ array::{Array, ArrayRef, StructArray, new_null_array}, compute::{CastOptions, cast_with_options}, - datatypes::{DataType::Struct, Field, FieldRef}, + datatypes::{DataType, DataType::Struct, Field, FieldRef}, }; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; /// Cast a struct column to match target struct fields, handling nested structs recursively. /// @@ -31,6 +31,7 @@ use std::sync::Arc; /// /// ## Field Matching Strategy /// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) +/// - **By Position**: When there is no name overlap and the field counts match, fields are cast by index /// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type /// - **Missing Fields**: Target fields not present in the source are filled with null values /// - **Extra Fields**: Source fields not present in the target are ignored @@ -54,16 +55,38 @@ fn cast_struct_column( target_fields: &[Arc], cast_options: &CastOptions, ) -> Result { + if source_col.data_type() == &DataType::Null + || (!source_col.is_empty() && source_col.null_count() == source_col.len()) + { + return Ok(new_null_array( + &Struct(target_fields.to_vec().into()), + source_col.len(), + )); + } + if let Some(source_struct) = source_col.as_any().downcast_ref::() { - validate_struct_compatibility(source_struct.fields(), target_fields)?; + let source_fields = source_struct.fields(); + let has_overlap = fields_have_name_overlap(source_fields, target_fields); + validate_struct_compatibility(source_fields, target_fields)?; let mut fields: Vec> = Vec::with_capacity(target_fields.len()); let mut arrays: Vec = Vec::with_capacity(target_fields.len()); let num_rows = source_col.len(); - for target_child_field in target_fields { + // Iterate target fields and pick source child either by name (when fields overlap) + // or by position (when there is no name overlap). + for (index, target_child_field) in target_fields.iter().enumerate() { fields.push(Arc::clone(target_child_field)); - match source_struct.column_by_name(target_child_field.name()) { + + // Determine the source child column: by name when overlapping names exist, + // otherwise by position. + let source_child_opt: Option<&ArrayRef> = if has_overlap { + source_struct.column_by_name(target_child_field.name()) + } else { + Some(source_struct.column(index)) + }; + + match source_child_opt { Some(source_child_col) => { let adapted_child = cast_column(source_child_col, target_child_field, cast_options) @@ -155,6 +178,15 @@ pub fn cast_column( ) -> Result { match target_field.data_type() { Struct(target_fields) => { + if source_col.data_type() == &DataType::Null + || (!source_col.is_empty() && source_col.null_count() == source_col.len()) + { + return Ok(new_null_array( + &Struct(target_fields.to_vec().into()), + source_col.len(), + )); + } + cast_struct_column(source_col, target_fields, cast_options) } _ => Ok(cast_with_options( @@ -165,6 +197,32 @@ pub fn cast_column( } } +/// Cast a struct array to another struct type by aligning child arrays using +/// field names instead of their physical order. +/// +/// This is a convenience wrapper around the internal function `cast_struct_column` that accepts +/// `Fields` directly instead of requiring a `Field` wrapper. +/// +/// See [`cast_column`] for detailed documentation on the casting behavior. +/// +/// # Arguments +/// * `array` - The source array to cast (must be a struct array) +/// * `target_fields` - The target struct field definitions +/// * `cast_options` - Options controlling cast behavior (strictness, formatting) +/// +/// # Returns +/// A `Result` containing the cast struct array +/// +/// # Errors +/// Returns an error if the source is not a struct array or if field casting fails +pub fn cast_struct_array_by_name( + array: &ArrayRef, + target_fields: &arrow::datatypes::Fields, + cast_options: &CastOptions, +) -> Result { + cast_struct_column(array, target_fields.as_ref(), cast_options) +} + /// Validates compatibility between source and target struct fields for casting operations. /// /// This function implements comprehensive struct compatibility checking by examining: @@ -204,6 +262,24 @@ pub fn validate_struct_compatibility( source_fields: &[FieldRef], target_fields: &[FieldRef], ) -> Result<()> { + let has_overlap = fields_have_name_overlap(source_fields, target_fields); + if !has_overlap { + if source_fields.len() != target_fields.len() { + return _plan_err!( + "Cannot cast struct with {} fields to {} fields without name overlap; positional mapping is ambiguous", + source_fields.len(), + target_fields.len() + ); + } + + for (source_field, target_field) in source_fields.iter().zip(target_fields.iter()) + { + validate_field_compatibility(source_field, target_field)?; + } + + return Ok(()); + } + // Check compatibility for each target field for target_field in target_fields { // Look for matching field in source by name @@ -211,44 +287,71 @@ pub fn validate_struct_compatibility( .iter() .find(|f| f.name() == target_field.name()) { - // Ensure nullability is compatible. It is invalid to cast a nullable - // source field to a non-nullable target field as this may discard - // null values. - if source_field.is_nullable() && !target_field.is_nullable() { + validate_field_compatibility(source_field, target_field)?; + } + // Missing fields in source are OK - they'll be filled with nulls + } + + // Extra fields in source are OK - they'll be ignored + Ok(()) +} + +fn validate_field_compatibility( + source_field: &Field, + target_field: &Field, +) -> Result<()> { + if source_field.data_type() == &DataType::Null { + return Ok(()); + } + + // Ensure nullability is compatible. It is invalid to cast a nullable + // source field to a non-nullable target field as this may discard + // null values. + if source_field.is_nullable() && !target_field.is_nullable() { + return _plan_err!( + "Cannot cast nullable struct field '{}' to non-nullable field", + target_field.name() + ); + } + + // Check if the matching field types are compatible + match (source_field.data_type(), target_field.data_type()) { + // Recursively validate nested structs + (Struct(source_nested), Struct(target_nested)) => { + validate_struct_compatibility(source_nested, target_nested)?; + } + // For non-struct types, use the existing castability check + _ => { + if !arrow::compute::can_cast_types( + source_field.data_type(), + target_field.data_type(), + ) { return _plan_err!( - "Cannot cast nullable struct field '{}' to non-nullable field", - target_field.name() + "Cannot cast struct field '{}' from type {} to type {}", + target_field.name(), + source_field.data_type(), + target_field.data_type() ); } - // Check if the matching field types are compatible - match (source_field.data_type(), target_field.data_type()) { - // Recursively validate nested structs - (Struct(source_nested), Struct(target_nested)) => { - validate_struct_compatibility(source_nested, target_nested)?; - } - // For non-struct types, use the existing castability check - _ => { - if !arrow::compute::can_cast_types( - source_field.data_type(), - target_field.data_type(), - ) { - return _plan_err!( - "Cannot cast struct field '{}' from type {} to type {}", - target_field.name(), - source_field.data_type(), - target_field.data_type() - ); - } - } - } } - // Missing fields in source are OK - they'll be filled with nulls } - // Extra fields in source are OK - they'll be ignored Ok(()) } +fn fields_have_name_overlap( + source_fields: &[FieldRef], + target_fields: &[FieldRef], +) -> bool { + let source_names: HashSet<&str> = source_fields + .iter() + .map(|field| field.name().as_str()) + .collect(); + target_fields + .iter() + .any(|field| source_names.contains(field.name().as_str())) +} + #[cfg(test)] mod tests { @@ -257,7 +360,7 @@ mod tests { use arrow::{ array::{ BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, - MapBuilder, StringArray, StringBuilder, + MapBuilder, NullArray, StringArray, StringBuilder, }, buffer::NullBuffer, datatypes::{DataType, Field, FieldRef, Int32Type}, @@ -428,11 +531,14 @@ mod tests { #[test] fn test_validate_struct_compatibility_missing_field_in_source() { - // Source struct: {field2: String} (missing field1) - let source_fields = vec![arc_field("field2", DataType::Utf8)]; + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; - // Target struct: {field1: Int32} - let target_fields = vec![arc_field("field1", DataType::Int32)]; + // Target struct: {field1: Int32, field2: Utf8} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; // Should be OK - missing fields will be filled with nulls let result = validate_struct_compatibility(&source_fields, &target_fields); @@ -455,6 +561,20 @@ mod tests { assert!(result.is_ok()); } + #[test] + fn test_validate_struct_compatibility_positional_no_overlap_mismatch_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Int32), + ]; + let target_fields = vec![arc_field("alpha", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("positional mapping is ambiguous")); + } + #[test] fn test_cast_struct_parent_nulls_retained() { let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; @@ -585,6 +705,33 @@ mod tests { assert!(missing.is_null(1)); } + #[test] + fn test_cast_null_struct_field_to_nested_struct() { + let null_inner = Arc::new(NullArray::new(2)) as ArrayRef; + let source_struct = StructArray::from(vec![( + arc_field("inner", DataType::Null), + Arc::clone(&null_inner), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "outer", + vec![struct_field("inner", vec![field("a", DataType::Int32)])], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.len(), 2); + assert!(inner.is_null(0)); + assert!(inner.is_null(1)); + + let inner_a = get_column_as!(inner, "a", Int32Array); + assert!(inner_a.is_null(0)); + assert!(inner_a.is_null(1)); + } + #[test] fn test_cast_struct_with_array_and_map_fields() { // Array field with second row null @@ -704,4 +851,34 @@ mod tests { assert_eq!(a_col.value(0), 1); assert_eq!(a_col.value(1), 2); } + + #[test] + fn test_cast_struct_positional_when_no_overlap() { + let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef; + let second = + Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef; + + let source_struct = StructArray::from(vec![ + (arc_field("left", DataType::Int32), first), + (arc_field("right", DataType::Utf8), second), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("a", DataType::Int64), field("b", DataType::Utf8)], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let a_col = get_column_as!(&struct_array, "a", Int64Array); + assert_eq!(a_col.value(0), 10); + assert_eq!(a_col.value(1), 20); + + let b_col = get_column_as!(&struct_array, "b", StringArray); + assert_eq!(b_col.value(0), "alpha"); + assert_eq!(b_col.value(1), "beta"); + } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index eda4952cf590b..911094d7ddf67 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -3704,7 +3704,19 @@ impl ScalarValue { } let scalar_array = self.to_array()?; - let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; + + // Use name-based struct casting for struct types + let cast_arr = match (scalar_array.data_type(), target_type) { + (DataType::Struct(_), DataType::Struct(target_fields)) => { + crate::nested_struct::cast_struct_array_by_name( + &scalar_array, + target_fields, + cast_options, + )? + } + _ => cast_with_options(&scalar_array, target_type, cast_options)?, + }; + ScalarValue::try_from_array(&cast_arr, 0) } diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 99c21d4abdb6e..d3a216cee9acc 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -274,7 +274,17 @@ impl ColumnarValue { Ok(args) } - /// Cast's this [ColumnarValue] to the specified `DataType` + /// Cast this [ColumnarValue] to the specified `DataType` + /// + /// # Struct Casting Behavior + /// + /// When casting struct types, fields are matched **by name** rather than position: + /// - Source fields are matched to target fields using case-sensitive name comparison + /// - Fields are reordered to match the target schema + /// - Missing target fields are filled with null arrays + /// - Extra source fields are ignored + /// + /// For non-struct types, uses Arrow's standard positional casting. pub fn cast_to( &self, cast_type: &DataType, @@ -283,12 +293,8 @@ impl ColumnarValue { let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); match self { ColumnarValue::Array(array) => { - ensure_date_array_timestamp_bounds(array, cast_type)?; - Ok(ColumnarValue::Array(kernels::cast::cast_with_options( - array, - cast_type, - &cast_options, - )?)) + let casted = cast_array_by_name(array, cast_type, &cast_options)?; + Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( scalar.cast_to_with_options(cast_type, &cast_options)?, @@ -297,6 +303,35 @@ impl ColumnarValue { } } +fn cast_array_by_name( + array: &ArrayRef, + cast_type: &DataType, + cast_options: &CastOptions<'static>, +) -> Result { + // If types are already equal, no cast needed + if array.data_type() == cast_type { + return Ok(Arc::clone(array)); + } + + match (array.data_type(), cast_type) { + (DataType::Struct(_source_fields), DataType::Struct(target_fields)) => { + datafusion_common::nested_struct::cast_struct_array_by_name( + array, + target_fields, + cast_options, + ) + } + _ => { + ensure_date_array_timestamp_bounds(array, cast_type)?; + Ok(kernels::cast::cast_with_options( + array, + cast_type, + cast_options, + )?) + } + } +} + fn ensure_date_array_timestamp_bounds( array: &ArrayRef, cast_type: &DataType, @@ -378,8 +413,8 @@ impl fmt::Display for ColumnarValue { mod tests { use super::*; use arrow::{ - array::{Date64Array, Int32Array}, - datatypes::TimeUnit, + array::{Date64Array, Int32Array, StructArray}, + datatypes::{Field, Fields, TimeUnit}, }; #[test] @@ -553,6 +588,102 @@ mod tests { ); } + #[test] + fn cast_struct_by_field_name() { + let source_fields = Fields::from(vec![ + Field::new("b", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + source_fields, + vec![ + Arc::new(Int32Array::from(vec![Some(3)])), + Arc::new(Int32Array::from(vec![Some(4)])), + ], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_a = struct_array + .column_by_name("a") + .expect("expected field a in cast result"); + let field_b = struct_array + .column_by_name("b") + .expect("expected field b in cast result"); + + assert_eq!( + field_a + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 4 + ); + assert_eq!( + field_b + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 3 + ); + } + + #[test] + fn cast_struct_missing_field_inserts_nulls() { + let source_fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + source_fields, + vec![Arc::new(Int32Array::from(vec![Some(5)]))], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_b = struct_array + .column_by_name("b") + .expect("expected missing field to be added"); + + assert!(field_b.is_null(0)); + } + #[test] fn cast_date64_array_to_timestamp_overflow() { let overflow_value = i64::MAX / 1_000_000 + 1; diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index c9b39eacefc6a..f64efe9f474dd 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -17,6 +17,7 @@ //! Coercion rules for matching argument types for binary operators +use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -1236,30 +1237,84 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { (Struct(lhs_fields), Struct(rhs_fields)) => { + // Field count must match for coercion if lhs_fields.len() != rhs_fields.len() { return None; } - let coerced_types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()) - .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type())) - .collect::>>()?; - - // preserve the field name and nullability - let orig_fields = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()); + // If the two structs have exactly the same set of field names (possibly in + // different order), prefer name-based coercion. Otherwise fall back to + // positional coercion which preserves backward compatibility. + if fields_have_same_names(lhs_fields, rhs_fields) { + return coerce_struct_by_name(lhs_fields, rhs_fields); + } - let fields: Vec = coerced_types - .into_iter() - .zip(orig_fields) - .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) - .collect(); - Some(Struct(fields.into())) + coerce_struct_by_position(lhs_fields, rhs_fields) } _ => None, } } +/// Return true if every left-field name exists in the right fields (and lengths are equal). +fn fields_have_same_names(lhs_fields: &Fields, rhs_fields: &Fields) -> bool { + let rhs_names: HashSet<&str> = rhs_fields.iter().map(|f| f.name().as_str()).collect(); + lhs_fields + .iter() + .all(|lf| rhs_names.contains(lf.name().as_str())) +} + +/// Coerce two structs by matching fields by name. Assumes the name-sets match. +fn coerce_struct_by_name(lhs_fields: &Fields, rhs_fields: &Fields) -> Option { + use arrow::datatypes::DataType::*; + + let rhs_by_name: HashMap<&str, &FieldRef> = + rhs_fields.iter().map(|f| (f.name().as_str(), f)).collect(); + + let mut coerced: Vec = Vec::with_capacity(lhs_fields.len()); + + for lhs in lhs_fields.iter() { + let rhs = rhs_by_name.get(lhs.name().as_str()).unwrap(); // safe: caller ensured names match + let coerced_type = comparison_coercion(lhs.data_type(), rhs.data_type())?; + let is_nullable = lhs.is_nullable() || rhs.is_nullable(); + coerced.push(Arc::new(Field::new( + lhs.name().clone(), + coerced_type, + is_nullable, + ))); + } + + Some(Struct(coerced.into())) +} + +/// Coerce two structs positionally (left-to-right). This preserves field names from +/// the left struct and uses the combined nullability. +fn coerce_struct_by_position( + lhs_fields: &Fields, + rhs_fields: &Fields, +) -> Option { + use arrow::datatypes::DataType::*; + + // First coerce individual types; fail early if any pair cannot be coerced. + let coerced_types: Vec = lhs_fields + .iter() + .zip(rhs_fields.iter()) + .map(|(l, r)| comparison_coercion(l.data_type(), r.data_type())) + .collect::>>()?; + + // Build final fields preserving left-side names and combined nullability. + let orig_pairs = lhs_fields.iter().zip(rhs_fields.iter()); + let fields: Vec = coerced_types + .into_iter() + .zip(orig_pairs) + .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) + .collect(); + + Some(Struct(fields.into())) +} + /// returns the result of coercing two fields to a common type fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> FieldRef { let is_nullable = lhs.is_nullable() || rhs.is_nullable(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 854e907d68b1a..c0c043901c89c 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -659,7 +659,16 @@ impl ExprSchemable for Expr { // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? - if can_cast_types(&this_type, cast_to_type) { + // Special handling for struct-to-struct casts with name-based field matching + let can_cast = match (&this_type, cast_to_type) { + (DataType::Struct(_), DataType::Struct(_)) => { + // Always allow struct-to-struct casts; field matching happens at runtime + true + } + _ => can_cast_types(&this_type, cast_to_type), + }; + + if can_cast { match self { Expr::ScalarSubquery(subquery) => { Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 8740ab072a1f5..380a08c515319 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -409,6 +409,9 @@ impl Optimizer { } // OptimizerRule was unsuccessful, but skipped failed rules is off, return error (Err(e), None) => { + if matches!(e, DataFusionError::Plan(_)) { + return Err(e); + } return Err(e.context(format!( "Optimizer rule '{}' failed", rule.name() @@ -499,11 +502,7 @@ mod tests { schema: Arc::new(DFSchema::empty()), }); let err = opt.optimize(plan, &config, &observe).unwrap_err(); - assert_eq!( - "Optimizer rule 'bad rule' failed\ncaused by\n\ - Error during planning: rule failed", - err.strip_backtrace() - ); + assert_eq!("Error during planning: rule failed", err.strip_backtrace()); } #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 6d62fbc38574d..8cd2bbcf7788a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -38,8 +38,8 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_expr::{ - BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and, - binary::BinaryTypeCoercer, lit, or, + BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, }; use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; @@ -641,6 +641,20 @@ impl ConstEvaluator { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::Cast(Cast { expr, data_type }) + | Expr::TryCast(TryCast { expr, data_type }) => { + if let ( + Ok(DataType::Struct(source_fields)), + DataType::Struct(target_fields), + ) = (expr.get_type(&DFSchema::empty()), data_type) + { + // Don't const-fold struct casts with different field counts + if source_fields.len() != target_fields.len() { + return false; + } + } + true + } Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) @@ -659,8 +673,6 @@ impl ConstEvaluator { | Expr::Like { .. } | Expr::SimilarTo { .. } | Expr::Case(_) - | Expr::Cast { .. } - | Expr::TryCast { .. } | Expr::InList { .. } => true, } } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index bd5c63a69979f..ba9bb56cd94d1 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -237,6 +237,11 @@ pub fn cast_with_options( Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) + } else if matches!((&expr_type, &cast_type), (Struct(_), Struct(_))) { + // Allow struct-to-struct casts even if Arrow's can_cast_types rejects them + // (e.g., field count mismatches). These will be handled by name-based casting + // at execution time via ColumnarValue::cast_to + Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}") } diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 481dde5be9f5c..8e0ee08d994a8 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -383,8 +383,10 @@ SELECT column2, column3, column4 FROM t; ---- {foo: a, xxx: b} {xxx: c, foo: d} {xxx: e} -# coerce structs with different field orders, -# should keep the same field values +# coerce structs with different field orders +# With name-based struct coercion, matching fields by name: +# column2={foo:a, xxx:b} unified with column3={xxx:c, foo:d} +# Result uses the THEN branch's field order (when executed): {xxx: b, foo: a} query ? SELECT case @@ -396,6 +398,7 @@ FROM t; {xxx: b, foo: a} # coerce structs with different field orders +# When ELSE branch executes, uses its field order: {xxx: c, foo: d} query ? SELECT case @@ -406,8 +409,9 @@ FROM t; ---- {xxx: c, foo: d} -# coerce structs with subset of fields -query error Failed to coerce then +# coerce structs with subset of fields - field count mismatch causes type coercion failure +# column3 has 2 fields but column4 has only 1 field +query error DataFusion error: type_coercion\ncaused by\nError during planning: Failed to coerce then .* and else .* to common types in CASE WHEN expression SELECT case when column1 > 0 then column3 diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index a91a5e7f870a9..5c9d91bbe1978 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -492,18 +492,6 @@ Struct("r": Utf8, "c": Float64) statement ok drop table t; -statement ok -create table t as values({r: 'a', c: 1}), ({c: 2.3, r: 'b'}); - -query ? -select * from t; ----- -{c: 1.0, r: a} -{c: 2.3, r: b} - -statement ok -drop table t; - ################################## ## Test Coalesce with Struct ################################## @@ -562,24 +550,12 @@ Struct("a": Float32, "b": Utf8View) statement ok drop table t; -# row() with incorrect order +# row() with incorrect order - row() is positional, not name-based statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'blue' to value of Float32 type create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('red', 1), row(2.3, 'blue')), (row('purple', 1), row('green', 2.3)); -# out of order struct literal -statement ok -create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); - -query ? -select * from t; ----- -{r: a, c: 1} -{r: b, c: 2} - -statement ok -drop table t; ################################## ## Test Array of Struct @@ -590,11 +566,6 @@ select [{r: 'a', c: 1}, {r: 'b', c: 2}]; ---- [{r: a, c: 1}, {r: b, c: 2}] -# Create a list of struct with different field types -query ? -select [{r: 'a', c: 1}, {c: 2, r: 'b'}]; ----- -[{c: 1, r: a}, {c: 2, r: b}] statement ok create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('a', 1), row('b', 2.3)); @@ -607,18 +578,6 @@ List(Struct("r": Utf8View, "c": Float32)) statement ok drop table t; -# create table with different struct type is fine -statement ok -create table t(a struct(r varchar, c int), b struct(c float, r varchar)) as values (row('a', 1), row(2.3, 'b')); - -# create array with different struct type should be cast -query T -select arrow_typeof([a, b]) from t; ----- -List(Struct("c": Float32, "r": Utf8View)) - -statement ok -drop table t; statement ok create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values (row('a', 1, 2.3), row('b', 2.3, 2)); @@ -845,3 +804,96 @@ NULL statement ok drop table nullable_parent_test; + +# Test struct casting with field reordering - string fields +query ? +SELECT CAST({b: 'b_value', a: 'a_value'} AS STRUCT(a VARCHAR, b VARCHAR)); +---- +{a: a_value, b: b_value} + +# Test struct casting with field reordering - integer fields +query ? +SELECT CAST({b: 3, a: 4} AS STRUCT(a INT, b INT)); +---- +{a: 4, b: 3} + +# Test with type casting AND field reordering +query ? +SELECT CAST({b: 3, a: 4} AS STRUCT(a BIGINT, b INT)); +---- +{a: 4, b: 3} + +# Test positional casting when there is no name overlap +query ? +SELECT CAST(struct(1, 'x') AS STRUCT(a INT, b VARCHAR)); +---- +{a: 1, b: x} + +# Test with missing field - should insert nulls +query ? +SELECT CAST({a: 1} AS STRUCT(a INT, b INT)); +---- +{a: 1, b: NULL} + +# Test with extra source field - should be ignored +query ? +SELECT CAST({a: 1, b: 2, extra: 3} AS STRUCT(a INT, b INT)); +---- +{a: 1, b: 2} + +# Test no overlap with mismatched field count +query error DataFusion error: (Plan error|Error during planning): Cannot cast struct with 3 fields to 2 fields without name overlap; positional mapping is ambiguous +SELECT CAST(struct(1, 'x', 'y') AS STRUCT(a INT, b VARCHAR)); + +# Test nested struct with field reordering +query ? +SELECT CAST( + {inner: {y: 2, x: 1}} + AS STRUCT(inner STRUCT(x INT, y INT)) +); +---- +{inner: {x: 1, y: 2}} + +# Test field reordering with table data +statement ok +CREATE TABLE struct_reorder_test ( + data STRUCT(b INT, a VARCHAR) +) AS VALUES + (struct(100, 'first')), + (struct(200, 'second')), + (struct(300, 'third')) +; + +query ? +SELECT CAST(data AS STRUCT(a VARCHAR, b INT)) AS casted_data FROM struct_reorder_test ORDER BY data['b']; +---- +{a: first, b: 100} +{a: second, b: 200} +{a: third, b: 300} + +statement ok +drop table struct_reorder_test; + +# Test casting struct with multiple levels of nesting and reordering +query ? +SELECT CAST( + {level1: {z: 100, y: 'inner', x: 1}} + AS STRUCT(level1 STRUCT(x INT, y VARCHAR, z INT)) +); +---- +{level1: {x: 1, y: inner, z: 100}} + +# Test field reordering with nulls in source +query ? +SELECT CAST( + {b: NULL::INT, a: 42} + AS STRUCT(a INT, b INT) +); +---- +{a: 42, b: NULL} + +# Test casting preserves struct-level nulls +query ? +SELECT CAST(NULL::STRUCT(b INT, a INT) AS STRUCT(a INT, b INT)); +---- +NULL