diff --git a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java index 987ffcd23..2bdfb00e3 100644 --- a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java +++ b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java @@ -25,6 +25,16 @@ public Type getType() { public abstract Expression.AggregationInvocation invocation(); + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + */ + @Value.Check + protected void check() { + VariadicParameterConsistencyValidator.validate(declaration(), arguments()); + } + public static ImmutableAggregateFunctionInvocation.Builder builder() { return ImmutableAggregateFunctionInvocation.builder(); } diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..f34d94495 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -799,6 +799,16 @@ public Type getType() { return outputType(); } + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + */ + @Value.Check + protected void check() { + VariadicParameterConsistencyValidator.validate(declaration(), arguments()); + } + public static ImmutableExpression.ScalarFunctionInvocation.Builder builder() { return ImmutableExpression.ScalarFunctionInvocation.builder(); } @@ -840,6 +850,16 @@ public Type getType() { public abstract AggregationInvocation invocation(); + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + */ + @Value.Check + protected void check() { + VariadicParameterConsistencyValidator.validate(declaration(), arguments()); + } + public static ImmutableExpression.WindowFunctionInvocation.Builder builder() { return ImmutableExpression.WindowFunctionInvocation.builder(); } diff --git a/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java new file mode 100644 index 000000000..8619d9274 --- /dev/null +++ b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java @@ -0,0 +1,99 @@ +package io.substrait.expression; + +import io.substrait.extension.SimpleExtension; +import io.substrait.type.Type; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Helper class for validating variadic parameter consistency in function invocations. Validates + * that when parameterConsistency is CONSISTENT, all variadic arguments have the same type (ignoring + * nullability). + */ +public class VariadicParameterConsistencyValidator { + + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + * + * @param func the function declaration + * @param arguments the function arguments to validate + * @throws AssertionError if validation fails + */ + public static void validate(SimpleExtension.Function func, List arguments) { + Optional variadic = func.variadic(); + if (!variadic.isPresent()) { + return; + } + + SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); + if (variadicBehavior.parameterConsistency() + != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { + // INCONSISTENT allows different types, so validation passes + // TODO: Even when parameterConsistency is INCONSISTENT, there can be implicit constraints + // across variadic parameters due to type parameters. For example, consider a function with: + // args: [value: "decimal", variadic: {min: 1, parameterConsistency: INCONSISTENT}] + // return: "decimal<38,S>" + // In this case, while the precision P can vary across variadic arguments, the scale S must + // be consistent across all variadic arguments (since it's used in the return type). The + // current implementation doesn't validate these type parameter constraints. According to + // the spec: "Each argument can be any possible concrete type afforded by the bounds of any + // parameter defined in the arguments specification." This means we need to check that type + // parameters that appear in the return type (or are otherwise constrained) are consistent + // across variadic arguments, even when parameterConsistency is INCONSISTENT. + return; + } + + // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) + List argumentTypes = + arguments.stream() + .filter(arg -> arg instanceof Expression || arg instanceof Type) + .map( + arg -> { + if (arg instanceof Expression) { + return ((Expression) arg).getType(); + } else { + return (Type) arg; + } + }) + .collect(Collectors.toList()); + + // Count how many Expression/Type arguments are in the fixed arguments (before variadic) + // Note: func.args() includes all argument types (Expression, Type, EnumArg), but we only + // care about Expression/Type arguments for type consistency checking + int fixedTypeArgCount = 0; + for (int i = 0; i < func.args().size() && i < arguments.size(); i++) { + FunctionArg arg = arguments.get(i); + if (arg instanceof Expression || arg instanceof Type) { + fixedTypeArgCount++; + } + } + + if (argumentTypes.size() <= fixedTypeArgCount) { + // No variadic arguments, validation passes + return; + } + + // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) + // Compare all variadic arguments to the first one for more informative error messages + // Variadic arguments start immediately after the fixed arguments + int firstVariadicArgIdx = fixedTypeArgCount; + if (firstVariadicArgIdx >= argumentTypes.size()) { + // Not enough variadic arguments provided, validation passes + return; + } + Type firstVariadicType = argumentTypes.get(firstVariadicArgIdx); + for (int i = firstVariadicArgIdx + 1; i < argumentTypes.size(); i++) { + Type currentType = argumentTypes.get(i); + if (!firstVariadicType.equalsIgnoringNullability(currentType)) { + throw new AssertionError( + String.format( + "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " + + "Argument at index %d has type %s but argument at index %d has type %s", + firstVariadicArgIdx, firstVariadicType, i, currentType)); + } + } + } +} diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 39d7c45e0..ab9058e77 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -235,6 +235,7 @@ enum ParameterConsistency { INCONSISTENT } + @Value.Default default ParameterConsistency parameterConsistency() { return ParameterConsistency.CONSISTENT; } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..15606f5a3 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -25,6 +25,16 @@ default R accept( return fnArgVisitor.visitType(fnDef, argIdx, this, context); } + /** + * Compares this type with another type, ignoring nullability differences. + * + * @param other the type to compare with + * @return true if the types are equal when both are treated as nullable + */ + default boolean equalsIgnoringNullability(Type other) { + return TypeCreator.asNullable(this).equals(TypeCreator.asNullable(other)); + } + @Value.Immutable abstract class Bool implements Type { public static ImmutableType.Bool.Builder builder() { diff --git a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java new file mode 100644 index 000000000..61d74f395 --- /dev/null +++ b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java @@ -0,0 +1,277 @@ +package io.substrait.expression; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.extension.ImmutableSimpleExtension; +import io.substrait.extension.SimpleExtension; +import io.substrait.function.ParameterizedType; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for variadic parameter consistency validation in Expression. */ +class VariadicParameterConsistencyTest { + + private static final TypeCreator R = TypeCreator.of(false); + + /** + * Helper method to create a ScalarFunctionInvocation and test if it validates correctly. The + * validation happens in the @Value.Check method when the Expression is built. + */ + private Expression.ScalarFunctionInvocation createScalarFunctionInvocation( + List args, + SimpleExtension.VariadicBehavior variadic, + List arguments) { + SimpleExtension.ScalarFunctionVariant declaration = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:test:variadic") + .name("test_func") + .args(args) + .variadic( + variadic != null ? java.util.Optional.of(variadic) : java.util.Optional.empty()) + .returnType(R.I64) + .options(java.util.Collections.emptyMap()) + .build(); + + return Expression.ScalarFunctionInvocation.builder() + .declaration(declaration) + .arguments(arguments) + .outputType(R.I64) + .options(java.util.Collections.emptyList()) + .build(); + } + + @Test + void testConsistentVariadicWithSameTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build())), + "Consistent variadic with same types should pass"); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).nullable(true).build(), + Expression.I64Literal.builder().value(3).nullable(true).build(), + Expression.I64Literal.builder().value(4).build())), + "Consistent variadic with same types but different nullability should pass"); + } + + @Test + void testConsistentVariadicWithDifferentTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build())), + "Consistent variadic with different types should fail"); + + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I32Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build())), + "Consistent variadic with different types should fail"); + } + + @Test + void testInconsistentVariadicWithDifferentTypes() { + // Function: test_func(i64, any...) with INCONSISTENT parameterConsistency + // When INCONSISTENT, each variadic argument can be a different type, but they all need to + // match the parameterized type constraint (any in this case) + ParameterizedType anyType = + ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); + + List args = + List.of( + SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build(), + SimpleExtension.ValueArgument.builder().value(anyType).name("variadic_arg").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency( + SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT) + .build(); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build())), + "Inconsistent variadic with different types should pass"); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I32Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build(), + Expression.StrLiteral.builder().value("test").build())), + "Inconsistent variadic with different types should pass"); + } + + @Test + void testConsistentVariadicWithWildcardType() { + // Function: test_func(any, any...) with CONSISTENT parameterConsistency + // The variadic arguments should all have the same concrete type + ParameterizedType anyType = + ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); + + List args = + List.of(SimpleExtension.ValueArgument.builder().value(anyType).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build())), + "Consistent variadic with wildcard type and same concrete types should pass"); + + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build())), + "Consistent variadic with wildcard type and different concrete types should fail"); + } + + @Test + void testConsistentVariadicWithMinGreaterThanOne() { + // Function: test_func(i64, i64...) with CONSISTENT and min=2 + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(2) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build(), + Expression.I64Literal.builder().value(4).build())), + "Consistent variadic with min=2 and same types should pass"); + + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build(), + Expression.FP64Literal.builder().value(4.0).build())), + "Consistent variadic with min=2 and different types should fail"); + } + + @Test + void testNoVariadicBehavior() { + // Function: test_func(i64) with no variadic behavior + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, null, List.of(Expression.I64Literal.builder().value(1).build())), + "No variadic behavior should always pass"); + } + + @Test + void testConsistentVariadicWithNullableTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).nullable(true).build(), + Expression.I64Literal.builder().value(3).build(), + Expression.I64Literal.builder().value(4).nullable(true).build())), + "Consistent variadic with same types but different nullability should pass"); + } +} diff --git a/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java b/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java new file mode 100644 index 000000000..e4e9d2a7a --- /dev/null +++ b/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java @@ -0,0 +1,32 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +/** Tests for VariadicBehavior, particularly parameterConsistency loading from YAML. */ +class VariadicBehaviorTest { + + @Test + void testParameterConsistencyLoading() { + String yamlContent = + "urn: extension:test:example\n" + + "scalar_functions:\n" + + " - name: test_func\n" + + " impls:\n" + + " - args:\n" + + " - name: arg1\n" + + " value: string\n" + + " variadic:\n" + + " min: 1\n" + + " parameterConsistency: CONSISTENT\n" + + " return: string\n"; + + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("test://example", yamlContent); + + assertEquals( + SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT, + collection.scalarFunctions().get(0).variadic().get().parameterConsistency()); + } +}