From e54b4f2072e52eab07b34a36332f8e732be09dca Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Thu, 27 Nov 2025 10:47:33 +0100 Subject: [PATCH 1/7] added parameterConsistency --- .../substrait/extension/SimpleExtension.java | 1 + .../extension/TypeExtensionTest.java | 23 ++ .../isthmus/expression/FunctionConverter.java | 49 ++- .../VariadicParameterConsistencyTest.java | 293 ++++++++++++++++++ 4 files changed, 359 insertions(+), 7 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 39d7c45e0..ab9058e77 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -235,6 +235,7 @@ enum ParameterConsistency { INCONSISTENT } + @Value.Default default ParameterConsistency parameterConsistency() { return ParameterConsistency.CONSISTENT; } diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index abb4008d5..fcbeeb391 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -97,4 +97,27 @@ void roundtripNumberedAnyTypes() { Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } + + @Test + public void testParameterConsistencyLoading() { + String yamlContent = + "urn: extension:test:example\n" + + "scalar_functions:\n" + + " - name: test_func\n" + + " impls:\n" + + " - args:\n" + + " - name: arg1\n" + + " value: string\n" + + " variadic:\n" + + " min: 1\n" + + " parameterConsistency: CONSISTENT\n" + + " return: string\n"; + + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("test://example", yamlContent); + + assertEquals( + SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT, + collection.scalarFunctions().get(0).variadic().get().parameterConsistency()); + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index d738ef157..ffa550c97 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -43,6 +43,7 @@ import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -244,7 +245,8 @@ private Optional signatureMatch(List inputTypes, Type outputType) { // Make sure that arguments & return are within bounds and match the types if (function.returnType() instanceof ParameterizedType && isMatch(outputType, (ParameterizedType) function.returnType()) - && inputTypesMatchDefinedArguments(inputTypes, args)) { + && inputTypesMatchDefinedArguments( + inputTypes, args, function.variadic().orElse(null))) { return Optional.of(function); } } @@ -256,17 +258,20 @@ && inputTypesMatchDefinedArguments(inputTypes, args)) { * Checks to see if the given input types satisfy the function arguments given. Checks that * *
    - *
  • Variadic arguments all have the same input type + *
  • Variadic arguments all have the same input type (when parameterConsistency is CONSISTENT) *
  • Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input * type *
* * @param inputTypes input types to check against arguments * @param args expected arguments as defined in a {@link SimpleExtension.Function} + * @param variadic the variadic behavior to check for consistency, or null if not variadic * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise */ - private boolean inputTypesMatchDefinedArguments( - List inputTypes, List args) { + boolean inputTypesMatchDefinedArguments( + List inputTypes, + List args, + SimpleExtension.@Nullable VariadicBehavior variadic) { Map> wildcardToType = new HashMap<>(); for (int i = 0; i < inputTypes.size(); i++) { @@ -292,9 +297,39 @@ private boolean inputTypesMatchDefinedArguments( } // If all the types match, check if the wildcard types are compatible. - // TODO: Determine if non-enumerated wildcard types (i.e. `any` as opposed to `any1`) need to - // have the same type. - return wildcardToType.values().stream().allMatch(s -> s.size() == 1); + // When parameterConsistency is INCONSISTENT, wildcard types can differ. + // When parameterConsistency is CONSISTENT (or not variadic), wildcard types must be the same. + boolean allowInconsistentWildcards = + variadic != null + && variadic.parameterConsistency() + == SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT; + if (!allowInconsistentWildcards + && !wildcardToType.values().stream().allMatch(s -> s.size() == 1)) { + return false; + } + + // Validate variadic argument consistency + if (variadic != null + && variadic.parameterConsistency() + == SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT + && inputTypes.size() > args.size()) { + // Check that all variadic arguments have the same type (ignoring nullability) + int firstVariadicArgIdx = Math.max(variadic.getMin() - 1, 0); + // Compare consecutive arguments starting from firstVariadicArgIdx + for (int i = firstVariadicArgIdx; i < inputTypes.size() - 1; i++) { + Type currentType = inputTypes.get(i); + Type nextType = inputTypes.get(i + 1); + // Compare types ignoring nullability - check both directions to ensure types match + boolean typesMatch = + currentType.accept(new IgnoreNullableAndParameters(nextType)) + || nextType.accept(new IgnoreNullableAndParameters(currentType)); + if (!typesMatch) { + return false; + } + } + } + + return true; } /** diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java new file mode 100644 index 000000000..3955e2e8c --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java @@ -0,0 +1,293 @@ +package io.substrait.isthmus.expression; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.extension.ImmutableSimpleExtension; +import io.substrait.extension.SimpleExtension; +import io.substrait.function.ParameterizedType; +import io.substrait.isthmus.PlanTestBase; +import io.substrait.isthmus.TypeConverter; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for variadic parameter consistency validation in FunctionConverter. */ +class VariadicParameterConsistencyTest extends PlanTestBase { + + private static final TypeCreator R = TypeCreator.of(false); + private static final TypeCreator N = TypeCreator.of(true); + + /** + * Helper method to test if input types match defined arguments with variadic behavior. Creates a + * minimal FunctionConverter and FunctionFinder to access the package-private method. + */ + private boolean testInputTypesMatch( + List inputTypes, + List args, + SimpleExtension.VariadicBehavior variadic) { + // Create a minimal test function to use with FunctionConverter + SimpleExtension.ScalarFunctionVariant testFunction = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:test:variadic") + .name("test_func") + .args(args) + .variadic( + variadic != null ? java.util.Optional.of(variadic) : java.util.Optional.empty()) + .returnType(R.I64) + .options(java.util.Collections.emptyMap()) + .build(); + + // Create a ScalarFunctionConverter with our test function + ScalarFunctionConverter converter = + new ScalarFunctionConverter( + List.of(testFunction), + java.util.Collections.emptyList(), + typeFactory, + TypeConverter.DEFAULT); + + // Create a test helper that can access the package-private method + TestHelper testHelper = new TestHelper(converter, testFunction); + return testHelper.testInputTypesMatch(inputTypes, args, variadic); + } + + /** + * Test helper class that extends ScalarFunctionConverter to access package-private + * inputTypesMatchDefinedArguments method. + */ + private static class TestHelper extends ScalarFunctionConverter { + private final SimpleExtension.ScalarFunctionVariant testFunction; + + TestHelper(ScalarFunctionConverter parent, SimpleExtension.ScalarFunctionVariant testFunction) { + super( + java.util.Collections.emptyList(), + java.util.Collections.emptyList(), + parent.typeFactory, + parent.typeConverter); + this.testFunction = testFunction; + } + + boolean testInputTypesMatch( + List inputTypes, + List args, + SimpleExtension.VariadicBehavior variadic) { + // Create a minimal FunctionFinder to access the package-private method + // We need a SqlOperator, so create a dummy one + org.apache.calcite.sql.SqlFunction dummyOperator = + new org.apache.calcite.sql.SqlFunction( + "test", + org.apache.calcite.sql.SqlKind.OTHER_FUNCTION, + org.apache.calcite.sql.type.ReturnTypes.explicit( + org.apache.calcite.sql.type.SqlTypeName.BIGINT), + null, + null, + org.apache.calcite.sql.SqlFunctionCategory.USER_DEFINED_FUNCTION); + + // Create a FunctionFinder with the test function so it can compute argRange properly + // FunctionFinder is a protected inner class, so we can access it from a subclass + FunctionFinder finder = new FunctionFinder("test_func", dummyOperator, List.of(testFunction)); + + // Access the package-private method + return finder.inputTypesMatchDefinedArguments(inputTypes, args, variadic); + } + } + + @Test + void testConsistentVariadicWithSameTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of( + SimpleExtension.ValueArgument.builder() + .value(R.I64) + .name("arg1") + .build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // All variadic arguments are i64 - should pass + assertTrue( + testInputTypesMatch( + List.of(R.I64, R.I64, R.I64), args, variadic), + "Consistent variadic with same types should match"); + + // All variadic arguments are i64 (with different nullability) - should pass + assertTrue( + testInputTypesMatch( + List.of(R.I64, R.I64, N.I64, R.I64), args, variadic), + "Consistent variadic with same types but different nullability should match"); + + } + + @Test + void testConsistentVariadicWithDifferentTypes() { + // Function: test_func(i64, any...) with CONSISTENT parameterConsistency + List args = + List.of( + SimpleExtension.ValueArgument.builder() + .value(R.I64) + .name("arg1") + .build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // Variadic arguments have different types - should fail + assertFalse( + testInputTypesMatch( + List.of(R.I64, R.I64, R.FP64), args, variadic), + "Consistent variadic with different types should not match"); + + assertFalse( + testInputTypesMatch( + List.of(R.I64, R.I32, R.I64), args, variadic), + "Consistent variadic with different types should not match"); + } + + @Test + void testInconsistentVariadicWithDifferentTypes() { + // Function: test_func(i64, any...) with INCONSISTENT parameterConsistency + // When INCONSISTENT, each variadic argument can be a different type, but they all need to + // match the parameterized type constraint (any in this case) + ParameterizedType anyType = + ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); + + List args = + List.of( + SimpleExtension.ValueArgument.builder() + .value(R.I64) + .name("arg1") + .build(), + SimpleExtension.ValueArgument.builder() + .value(anyType) + .name("variadic_arg") + .build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency( + SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT) + .build(); + + // Variadic arguments have different types - should pass with INCONSISTENT and wildcard type + assertTrue( + testInputTypesMatch( + List.of(R.I64, R.I64, R.FP64), args, variadic), + "Inconsistent variadic with different types should match when using wildcard type"); + + assertTrue( + testInputTypesMatch( + List.of(R.I64, R.I32, R.FP64, R.STRING), args, variadic), + "Inconsistent variadic with different types should match when using wildcard type"); + } + + @Test + void testConsistentVariadicWithWildcardType() { + // Function: test_func(any, any...) with CONSISTENT parameterConsistency + // The variadic arguments should all have the same concrete type + ParameterizedType anyType = + ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); + + List args = + List.of( + SimpleExtension.ValueArgument.builder() + .value(anyType) + .name("arg1") + .build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // All variadic arguments are i64 - should pass + assertTrue( + testInputTypesMatch( + List.of(R.I64, R.I64, R.I64), args, variadic), + "Consistent variadic with wildcard type and same concrete types should match"); + + // Variadic arguments have different types - should fail + assertFalse( + testInputTypesMatch( + List.of(R.I64, R.I64, R.FP64), args, variadic), + "Consistent variadic with wildcard type and different concrete types should not match"); + } + + @Test + void testConsistentVariadicWithMinGreaterThanOne() { + // Function: test_func(i64, i64...) with CONSISTENT and min=2 + List args = + List.of( + SimpleExtension.ValueArgument.builder() + .value(R.I64) + .name("arg1") + .build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(2) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // All variadic arguments are i64 - should pass + assertTrue( + testInputTypesMatch( + List.of(R.I64, R.I64, R.I64, R.I64), args, variadic), + "Consistent variadic with min=2 and same types should match"); + + // Variadic arguments have different types - should fail + assertFalse( + testInputTypesMatch( + List.of(R.I64, R.I64, R.I64, R.FP64), args, variadic), + "Consistent variadic with min=2 and different types should not match"); + } + + @Test + void testNoVariadicBehavior() { + // Function: test_func(i64) with no variadic behavior + List args = + List.of( + SimpleExtension.ValueArgument.builder() + .value(R.I64) + .name("arg1") + .build()); + + // No variadic behavior - should pass regardless of consistency + assertTrue( + testInputTypesMatch(List.of(R.I64), args, null), + "No variadic behavior should always match"); + } + + @Test + void testConsistentVariadicWithNullableTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of( + SimpleExtension.ValueArgument.builder() + .value(R.I64) + .name("arg1") + .build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // Mix of nullable and non-nullable i64 - should pass (nullability is ignored) + assertTrue( + testInputTypesMatch( + List.of(R.I64, N.I64, R.I64, N.I64), args, variadic), + "Consistent variadic with same types but different nullability should match"); + } +} + From 033f093440baa019e00b3bc93e45981bbd59cda7 Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Thu, 27 Nov 2025 10:59:41 +0100 Subject: [PATCH 2/7] fmt --- .../extension/TypeExtensionTest.java | 34 ++++---- .../isthmus/expression/FunctionConverter.java | 5 +- .../VariadicParameterConsistencyTest.java | 81 +++++-------------- 3 files changed, 40 insertions(+), 80 deletions(-) diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index fcbeeb391..e2dc352ba 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -101,23 +101,23 @@ void roundtripNumberedAnyTypes() { @Test public void testParameterConsistencyLoading() { String yamlContent = - "urn: extension:test:example\n" - + "scalar_functions:\n" - + " - name: test_func\n" - + " impls:\n" - + " - args:\n" - + " - name: arg1\n" - + " value: string\n" - + " variadic:\n" - + " min: 1\n" - + " parameterConsistency: CONSISTENT\n" - + " return: string\n"; + "urn: extension:test:example\n" + + "scalar_functions:\n" + + " - name: test_func\n" + + " impls:\n" + + " - args:\n" + + " - name: arg1\n" + + " value: string\n" + + " variadic:\n" + + " min: 1\n" + + " parameterConsistency: CONSISTENT\n" + + " return: string\n"; - SimpleExtension.ExtensionCollection collection = - SimpleExtension.load("test://example", yamlContent); + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("test://example", yamlContent); - assertEquals( - SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT, - collection.scalarFunctions().get(0).variadic().get().parameterConsistency()); - } + assertEquals( + SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT, + collection.scalarFunctions().get(0).variadic().get().parameterConsistency()); + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index ffa550c97..7c37eeabd 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -20,6 +20,7 @@ import io.substrait.isthmus.expression.FunctionMappings.TypeBasedResolver; import io.substrait.type.Type; import io.substrait.util.Util; + import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -37,6 +38,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; @@ -258,7 +260,8 @@ && inputTypesMatchDefinedArguments( * Checks to see if the given input types satisfy the function arguments given. Checks that * *
    - *
  • Variadic arguments all have the same input type (when parameterConsistency is CONSISTENT) + *
  • Variadic arguments all have the same input type (when parameterConsistency is + * CONSISTENT) *
  • Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input * type *
diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java index 3955e2e8c..38ff68cde 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java @@ -97,11 +97,7 @@ boolean testInputTypesMatch( void testConsistentVariadicWithSameTypes() { // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency List args = - List.of( - SimpleExtension.ValueArgument.builder() - .value(R.I64) - .name("arg1") - .build()); + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); SimpleExtension.VariadicBehavior variadic = ImmutableSimpleExtension.VariadicBehavior.builder() @@ -111,27 +107,20 @@ void testConsistentVariadicWithSameTypes() { // All variadic arguments are i64 - should pass assertTrue( - testInputTypesMatch( - List.of(R.I64, R.I64, R.I64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic), "Consistent variadic with same types should match"); // All variadic arguments are i64 (with different nullability) - should pass assertTrue( - testInputTypesMatch( - List.of(R.I64, R.I64, N.I64, R.I64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, N.I64, R.I64), args, variadic), "Consistent variadic with same types but different nullability should match"); - } @Test void testConsistentVariadicWithDifferentTypes() { // Function: test_func(i64, any...) with CONSISTENT parameterConsistency List args = - List.of( - SimpleExtension.ValueArgument.builder() - .value(R.I64) - .name("arg1") - .build()); + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); SimpleExtension.VariadicBehavior variadic = ImmutableSimpleExtension.VariadicBehavior.builder() @@ -141,13 +130,11 @@ void testConsistentVariadicWithDifferentTypes() { // Variadic arguments have different types - should fail assertFalse( - testInputTypesMatch( - List.of(R.I64, R.I64, R.FP64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), "Consistent variadic with different types should not match"); assertFalse( - testInputTypesMatch( - List.of(R.I64, R.I32, R.I64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I32, R.I64), args, variadic), "Consistent variadic with different types should not match"); } @@ -161,14 +148,8 @@ void testInconsistentVariadicWithDifferentTypes() { List args = List.of( - SimpleExtension.ValueArgument.builder() - .value(R.I64) - .name("arg1") - .build(), - SimpleExtension.ValueArgument.builder() - .value(anyType) - .name("variadic_arg") - .build()); + SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build(), + SimpleExtension.ValueArgument.builder().value(anyType).name("variadic_arg").build()); SimpleExtension.VariadicBehavior variadic = ImmutableSimpleExtension.VariadicBehavior.builder() @@ -179,13 +160,11 @@ void testInconsistentVariadicWithDifferentTypes() { // Variadic arguments have different types - should pass with INCONSISTENT and wildcard type assertTrue( - testInputTypesMatch( - List.of(R.I64, R.I64, R.FP64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), "Inconsistent variadic with different types should match when using wildcard type"); assertTrue( - testInputTypesMatch( - List.of(R.I64, R.I32, R.FP64, R.STRING), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I32, R.FP64, R.STRING), args, variadic), "Inconsistent variadic with different types should match when using wildcard type"); } @@ -197,11 +176,7 @@ void testConsistentVariadicWithWildcardType() { ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); List args = - List.of( - SimpleExtension.ValueArgument.builder() - .value(anyType) - .name("arg1") - .build()); + List.of(SimpleExtension.ValueArgument.builder().value(anyType).name("arg1").build()); SimpleExtension.VariadicBehavior variadic = ImmutableSimpleExtension.VariadicBehavior.builder() @@ -211,14 +186,12 @@ void testConsistentVariadicWithWildcardType() { // All variadic arguments are i64 - should pass assertTrue( - testInputTypesMatch( - List.of(R.I64, R.I64, R.I64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic), "Consistent variadic with wildcard type and same concrete types should match"); // Variadic arguments have different types - should fail assertFalse( - testInputTypesMatch( - List.of(R.I64, R.I64, R.FP64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), "Consistent variadic with wildcard type and different concrete types should not match"); } @@ -226,11 +199,7 @@ void testConsistentVariadicWithWildcardType() { void testConsistentVariadicWithMinGreaterThanOne() { // Function: test_func(i64, i64...) with CONSISTENT and min=2 List args = - List.of( - SimpleExtension.ValueArgument.builder() - .value(R.I64) - .name("arg1") - .build()); + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); SimpleExtension.VariadicBehavior variadic = ImmutableSimpleExtension.VariadicBehavior.builder() @@ -240,14 +209,12 @@ void testConsistentVariadicWithMinGreaterThanOne() { // All variadic arguments are i64 - should pass assertTrue( - testInputTypesMatch( - List.of(R.I64, R.I64, R.I64, R.I64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.I64), args, variadic), "Consistent variadic with min=2 and same types should match"); // Variadic arguments have different types - should fail assertFalse( - testInputTypesMatch( - List.of(R.I64, R.I64, R.I64, R.FP64), args, variadic), + testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.FP64), args, variadic), "Consistent variadic with min=2 and different types should not match"); } @@ -255,11 +222,7 @@ void testConsistentVariadicWithMinGreaterThanOne() { void testNoVariadicBehavior() { // Function: test_func(i64) with no variadic behavior List args = - List.of( - SimpleExtension.ValueArgument.builder() - .value(R.I64) - .name("arg1") - .build()); + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); // No variadic behavior - should pass regardless of consistency assertTrue( @@ -271,11 +234,7 @@ void testNoVariadicBehavior() { void testConsistentVariadicWithNullableTypes() { // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency List args = - List.of( - SimpleExtension.ValueArgument.builder() - .value(R.I64) - .name("arg1") - .build()); + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); SimpleExtension.VariadicBehavior variadic = ImmutableSimpleExtension.VariadicBehavior.builder() @@ -285,9 +244,7 @@ void testConsistentVariadicWithNullableTypes() { // Mix of nullable and non-nullable i64 - should pass (nullability is ignored) assertTrue( - testInputTypesMatch( - List.of(R.I64, N.I64, R.I64, N.I64), args, variadic), + testInputTypesMatch(List.of(R.I64, N.I64, R.I64, N.I64), args, variadic), "Consistent variadic with same types but different nullability should match"); } } - From f8abdcd2fa3998fab45ce11fef956e698c39d1c2 Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Thu, 4 Dec 2025 10:52:40 +0100 Subject: [PATCH 3/7] applying comments --- .../AggregateFunctionInvocation.java | 55 ++++ .../io/substrait/expression/Expression.java | 110 +++++++ .../VariadicParameterConsistencyTest.java | 289 ++++++++++++++++++ .../extension/TypeExtensionTest.java | 22 -- .../extension/VariadicBehaviorTest.java | 33 ++ .../isthmus/expression/FunctionConverter.java | 52 +--- .../VariadicParameterConsistencyTest.java | 250 --------------- 7 files changed, 494 insertions(+), 317 deletions(-) create mode 100644 core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java create mode 100644 core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java delete mode 100644 isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java diff --git a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java index 987ffcd23..04edf3268 100644 --- a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java +++ b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java @@ -25,6 +25,61 @@ public Type getType() { public abstract Expression.AggregationInvocation invocation(); + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + */ + @Value.Check + protected void check() { + SimpleExtension.Function func = declaration(); + java.util.Optional variadic = func.variadic(); + if (!variadic.isPresent()) { + return; + } + + SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); + if (variadicBehavior.parameterConsistency() + != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { + // INCONSISTENT allows different types, so validation passes + return; + } + + // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) + List argumentTypes = + arguments().stream() + .filter(arg -> arg instanceof Expression || arg instanceof Type) + .map( + arg -> { + if (arg instanceof Expression) { + return ((Expression) arg).getType(); + } else { + return (Type) arg; + } + }) + .collect(java.util.stream.Collectors.toList()); + + int fixedArgCount = func.args().size(); + if (argumentTypes.size() <= fixedArgCount) { + // No variadic arguments, validation passes + return; + } + + // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) + int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); + for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { + Type currentType = argumentTypes.get(i); + Type nextType = argumentTypes.get(i + 1); + // Normalize both types to nullable for comparison (ignoring nullability) + assert io.substrait.type.TypeCreator.asNullable(currentType) + .equals(io.substrait.type.TypeCreator.asNullable(nextType)) + : String.format( + "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " + + "Argument at index %d has type %s but argument at index %d has type %s", + i, currentType, i + 1, nextType); + } + } + public static ImmutableAggregateFunctionInvocation.Builder builder() { return ImmutableAggregateFunctionInvocation.builder(); } diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..ef8ff18b0 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -799,6 +799,61 @@ public Type getType() { return outputType(); } + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + */ + @Value.Check + protected void check() { + SimpleExtension.Function func = declaration(); + java.util.Optional variadic = func.variadic(); + if (!variadic.isPresent()) { + return; + } + + SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); + if (variadicBehavior.parameterConsistency() + != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { + // INCONSISTENT allows different types, so validation passes + return; + } + + // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) + List argumentTypes = + arguments().stream() + .filter(arg -> arg instanceof Expression || arg instanceof Type) + .map( + arg -> { + if (arg instanceof Expression) { + return ((Expression) arg).getType(); + } else { + return (Type) arg; + } + }) + .collect(java.util.stream.Collectors.toList()); + + int fixedArgCount = func.args().size(); + if (argumentTypes.size() <= fixedArgCount) { + // No variadic arguments, validation passes + return; + } + + // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) + int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); + for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { + Type currentType = argumentTypes.get(i); + Type nextType = argumentTypes.get(i + 1); + // Normalize both types to nullable for comparison (ignoring nullability) + assert io.substrait.type.TypeCreator.asNullable(currentType) + .equals(io.substrait.type.TypeCreator.asNullable(nextType)) + : String.format( + "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " + + "Argument at index %d has type %s but argument at index %d has type %s", + i, currentType, i + 1, nextType); + } + } + public static ImmutableExpression.ScalarFunctionInvocation.Builder builder() { return ImmutableExpression.ScalarFunctionInvocation.builder(); } @@ -840,6 +895,61 @@ public Type getType() { public abstract AggregationInvocation invocation(); + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + */ + @Value.Check + protected void check() { + SimpleExtension.Function func = declaration(); + java.util.Optional variadic = func.variadic(); + if (!variadic.isPresent()) { + return; + } + + SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); + if (variadicBehavior.parameterConsistency() + != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { + // INCONSISTENT allows different types, so validation passes + return; + } + + // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) + List argumentTypes = + arguments().stream() + .filter(arg -> arg instanceof Expression || arg instanceof Type) + .map( + arg -> { + if (arg instanceof Expression) { + return ((Expression) arg).getType(); + } else { + return (Type) arg; + } + }) + .collect(java.util.stream.Collectors.toList()); + + int fixedArgCount = func.args().size(); + if (argumentTypes.size() <= fixedArgCount) { + // No variadic arguments, validation passes + return; + } + + // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) + int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); + for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { + Type currentType = argumentTypes.get(i); + Type nextType = argumentTypes.get(i + 1); + // Normalize both types to nullable for comparison (ignoring nullability) + assert io.substrait.type.TypeCreator.asNullable(currentType) + .equals(io.substrait.type.TypeCreator.asNullable(nextType)) + : String.format( + "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " + + "Argument at index %d has type %s but argument at index %d has type %s", + i, currentType, i + 1, nextType); + } + } + public static ImmutableExpression.WindowFunctionInvocation.Builder builder() { return ImmutableExpression.WindowFunctionInvocation.builder(); } diff --git a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java new file mode 100644 index 000000000..d936f1b44 --- /dev/null +++ b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java @@ -0,0 +1,289 @@ +package io.substrait.expression; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.extension.ImmutableSimpleExtension; +import io.substrait.extension.SimpleExtension; +import io.substrait.function.ParameterizedType; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for variadic parameter consistency validation in Expression. */ +class VariadicParameterConsistencyTest { + + private static final TypeCreator R = TypeCreator.of(false); + private static final TypeCreator N = TypeCreator.of(true); + + /** + * Helper method to create a ScalarFunctionInvocation and test if it validates correctly. The + * validation happens in the @Value.Check method when the Expression is built. + */ + private Expression.ScalarFunctionInvocation createScalarFunctionInvocation( + List args, + SimpleExtension.VariadicBehavior variadic, + List arguments) { + SimpleExtension.ScalarFunctionVariant declaration = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:test:variadic") + .name("test_func") + .args(args) + .variadic( + variadic != null ? java.util.Optional.of(variadic) : java.util.Optional.empty()) + .returnType(R.I64) + .options(java.util.Collections.emptyMap()) + .build(); + + return Expression.ScalarFunctionInvocation.builder() + .declaration(declaration) + .arguments(arguments) + .outputType(R.I64) + .options(java.util.Collections.emptyList()) + .build(); + } + + @Test + void testConsistentVariadicWithSameTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // All variadic arguments are i64 - should pass + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build())), + "Consistent variadic with same types should pass"); + + // All variadic arguments are i64 (with different nullability) - should pass + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).nullable(true).build(), + Expression.I64Literal.builder().value(3).nullable(true).build(), + Expression.I64Literal.builder().value(4).build())), + "Consistent variadic with same types but different nullability should pass"); + } + + @Test + void testConsistentVariadicWithDifferentTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // Variadic arguments have different types - should fail + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build())), + "Consistent variadic with different types should fail"); + + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I32Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build())), + "Consistent variadic with different types should fail"); + } + + @Test + void testInconsistentVariadicWithDifferentTypes() { + // Function: test_func(i64, any...) with INCONSISTENT parameterConsistency + // When INCONSISTENT, each variadic argument can be a different type, but they all need to + // match the parameterized type constraint (any in this case) + ParameterizedType anyType = + ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); + + List args = + List.of( + SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build(), + SimpleExtension.ValueArgument.builder().value(anyType).name("variadic_arg").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency( + SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT) + .build(); + + // Variadic arguments have different types - should pass with INCONSISTENT + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build())), + "Inconsistent variadic with different types should pass"); + + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I32Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build(), + Expression.StrLiteral.builder().value("test").build())), + "Inconsistent variadic with different types should pass"); + } + + @Test + void testConsistentVariadicWithWildcardType() { + // Function: test_func(any, any...) with CONSISTENT parameterConsistency + // The variadic arguments should all have the same concrete type + ParameterizedType anyType = + ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); + + List args = + List.of(SimpleExtension.ValueArgument.builder().value(anyType).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // All variadic arguments are i64 - should pass + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build())), + "Consistent variadic with wildcard type and same concrete types should pass"); + + // Variadic arguments have different types - should fail + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.FP64Literal.builder().value(3.0).build())), + "Consistent variadic with wildcard type and different concrete types should fail"); + } + + @Test + void testConsistentVariadicWithMinGreaterThanOne() { + // Function: test_func(i64, i64...) with CONSISTENT and min=2 + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(2) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // All variadic arguments are i64 - should pass + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build(), + Expression.I64Literal.builder().value(4).build())), + "Consistent variadic with min=2 and same types should pass"); + + // Variadic arguments have different types - should fail + assertThrows( + AssertionError.class, + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).build(), + Expression.I64Literal.builder().value(3).build(), + Expression.FP64Literal.builder().value(4.0).build())), + "Consistent variadic with min=2 and different types should fail"); + } + + @Test + void testNoVariadicBehavior() { + // Function: test_func(i64) with no variadic behavior + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + // No variadic behavior - should pass regardless of consistency + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, null, List.of(Expression.I64Literal.builder().value(1).build())), + "No variadic behavior should always pass"); + } + + @Test + void testConsistentVariadicWithNullableTypes() { + // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency + List args = + List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); + + SimpleExtension.VariadicBehavior variadic = + ImmutableSimpleExtension.VariadicBehavior.builder() + .min(1) + .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) + .build(); + + // Mix of nullable and non-nullable i64 - should pass (nullability is ignored) + assertDoesNotThrow( + () -> + createScalarFunctionInvocation( + args, + variadic, + List.of( + Expression.I64Literal.builder().value(1).build(), + Expression.I64Literal.builder().value(2).nullable(true).build(), + Expression.I64Literal.builder().value(3).build(), + Expression.I64Literal.builder().value(4).nullable(true).build())), + "Consistent variadic with same types but different nullability should pass"); + } +} + diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index e2dc352ba..9d270bff9 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -98,26 +98,4 @@ void roundtripNumberedAnyTypes() { assertEquals(plan, planReturned); } - @Test - public void testParameterConsistencyLoading() { - String yamlContent = - "urn: extension:test:example\n" - + "scalar_functions:\n" - + " - name: test_func\n" - + " impls:\n" - + " - args:\n" - + " - name: arg1\n" - + " value: string\n" - + " variadic:\n" - + " min: 1\n" - + " parameterConsistency: CONSISTENT\n" - + " return: string\n"; - - SimpleExtension.ExtensionCollection collection = - SimpleExtension.load("test://example", yamlContent); - - assertEquals( - SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT, - collection.scalarFunctions().get(0).variadic().get().parameterConsistency()); - } } diff --git a/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java b/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java new file mode 100644 index 000000000..bc4a7f202 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java @@ -0,0 +1,33 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +/** Tests for VariadicBehavior, particularly parameterConsistency loading from YAML. */ +class VariadicBehaviorTest { + + @Test + void testParameterConsistencyLoading() { + String yamlContent = + "urn: extension:test:example\n" + + "scalar_functions:\n" + + " - name: test_func\n" + + " impls:\n" + + " - args:\n" + + " - name: arg1\n" + + " value: string\n" + + " variadic:\n" + + " min: 1\n" + + " parameterConsistency: CONSISTENT\n" + + " return: string\n"; + + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("test://example", yamlContent); + + assertEquals( + SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT, + collection.scalarFunctions().get(0).variadic().get().parameterConsistency()); + } +} + diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 7c37eeabd..d738ef157 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -20,7 +20,6 @@ import io.substrait.isthmus.expression.FunctionMappings.TypeBasedResolver; import io.substrait.type.Type; import io.substrait.util.Util; - import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -38,14 +37,12 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; - import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; -import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -247,8 +244,7 @@ private Optional signatureMatch(List inputTypes, Type outputType) { // Make sure that arguments & return are within bounds and match the types if (function.returnType() instanceof ParameterizedType && isMatch(outputType, (ParameterizedType) function.returnType()) - && inputTypesMatchDefinedArguments( - inputTypes, args, function.variadic().orElse(null))) { + && inputTypesMatchDefinedArguments(inputTypes, args)) { return Optional.of(function); } } @@ -260,21 +256,17 @@ && inputTypesMatchDefinedArguments( * Checks to see if the given input types satisfy the function arguments given. Checks that * *
    - *
  • Variadic arguments all have the same input type (when parameterConsistency is - * CONSISTENT) + *
  • Variadic arguments all have the same input type *
  • Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input * type *
* * @param inputTypes input types to check against arguments * @param args expected arguments as defined in a {@link SimpleExtension.Function} - * @param variadic the variadic behavior to check for consistency, or null if not variadic * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise */ - boolean inputTypesMatchDefinedArguments( - List inputTypes, - List args, - SimpleExtension.@Nullable VariadicBehavior variadic) { + private boolean inputTypesMatchDefinedArguments( + List inputTypes, List args) { Map> wildcardToType = new HashMap<>(); for (int i = 0; i < inputTypes.size(); i++) { @@ -300,39 +292,9 @@ boolean inputTypesMatchDefinedArguments( } // If all the types match, check if the wildcard types are compatible. - // When parameterConsistency is INCONSISTENT, wildcard types can differ. - // When parameterConsistency is CONSISTENT (or not variadic), wildcard types must be the same. - boolean allowInconsistentWildcards = - variadic != null - && variadic.parameterConsistency() - == SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT; - if (!allowInconsistentWildcards - && !wildcardToType.values().stream().allMatch(s -> s.size() == 1)) { - return false; - } - - // Validate variadic argument consistency - if (variadic != null - && variadic.parameterConsistency() - == SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT - && inputTypes.size() > args.size()) { - // Check that all variadic arguments have the same type (ignoring nullability) - int firstVariadicArgIdx = Math.max(variadic.getMin() - 1, 0); - // Compare consecutive arguments starting from firstVariadicArgIdx - for (int i = firstVariadicArgIdx; i < inputTypes.size() - 1; i++) { - Type currentType = inputTypes.get(i); - Type nextType = inputTypes.get(i + 1); - // Compare types ignoring nullability - check both directions to ensure types match - boolean typesMatch = - currentType.accept(new IgnoreNullableAndParameters(nextType)) - || nextType.accept(new IgnoreNullableAndParameters(currentType)); - if (!typesMatch) { - return false; - } - } - } - - return true; + // TODO: Determine if non-enumerated wildcard types (i.e. `any` as opposed to `any1`) need to + // have the same type. + return wildcardToType.values().stream().allMatch(s -> s.size() == 1); } /** diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java deleted file mode 100644 index 38ff68cde..000000000 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/VariadicParameterConsistencyTest.java +++ /dev/null @@ -1,250 +0,0 @@ -package io.substrait.isthmus.expression; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.substrait.extension.ImmutableSimpleExtension; -import io.substrait.extension.SimpleExtension; -import io.substrait.function.ParameterizedType; -import io.substrait.isthmus.PlanTestBase; -import io.substrait.isthmus.TypeConverter; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.util.List; -import org.junit.jupiter.api.Test; - -/** Tests for variadic parameter consistency validation in FunctionConverter. */ -class VariadicParameterConsistencyTest extends PlanTestBase { - - private static final TypeCreator R = TypeCreator.of(false); - private static final TypeCreator N = TypeCreator.of(true); - - /** - * Helper method to test if input types match defined arguments with variadic behavior. Creates a - * minimal FunctionConverter and FunctionFinder to access the package-private method. - */ - private boolean testInputTypesMatch( - List inputTypes, - List args, - SimpleExtension.VariadicBehavior variadic) { - // Create a minimal test function to use with FunctionConverter - SimpleExtension.ScalarFunctionVariant testFunction = - ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .urn("extension:test:variadic") - .name("test_func") - .args(args) - .variadic( - variadic != null ? java.util.Optional.of(variadic) : java.util.Optional.empty()) - .returnType(R.I64) - .options(java.util.Collections.emptyMap()) - .build(); - - // Create a ScalarFunctionConverter with our test function - ScalarFunctionConverter converter = - new ScalarFunctionConverter( - List.of(testFunction), - java.util.Collections.emptyList(), - typeFactory, - TypeConverter.DEFAULT); - - // Create a test helper that can access the package-private method - TestHelper testHelper = new TestHelper(converter, testFunction); - return testHelper.testInputTypesMatch(inputTypes, args, variadic); - } - - /** - * Test helper class that extends ScalarFunctionConverter to access package-private - * inputTypesMatchDefinedArguments method. - */ - private static class TestHelper extends ScalarFunctionConverter { - private final SimpleExtension.ScalarFunctionVariant testFunction; - - TestHelper(ScalarFunctionConverter parent, SimpleExtension.ScalarFunctionVariant testFunction) { - super( - java.util.Collections.emptyList(), - java.util.Collections.emptyList(), - parent.typeFactory, - parent.typeConverter); - this.testFunction = testFunction; - } - - boolean testInputTypesMatch( - List inputTypes, - List args, - SimpleExtension.VariadicBehavior variadic) { - // Create a minimal FunctionFinder to access the package-private method - // We need a SqlOperator, so create a dummy one - org.apache.calcite.sql.SqlFunction dummyOperator = - new org.apache.calcite.sql.SqlFunction( - "test", - org.apache.calcite.sql.SqlKind.OTHER_FUNCTION, - org.apache.calcite.sql.type.ReturnTypes.explicit( - org.apache.calcite.sql.type.SqlTypeName.BIGINT), - null, - null, - org.apache.calcite.sql.SqlFunctionCategory.USER_DEFINED_FUNCTION); - - // Create a FunctionFinder with the test function so it can compute argRange properly - // FunctionFinder is a protected inner class, so we can access it from a subclass - FunctionFinder finder = new FunctionFinder("test_func", dummyOperator, List.of(testFunction)); - - // Access the package-private method - return finder.inputTypesMatchDefinedArguments(inputTypes, args, variadic); - } - } - - @Test - void testConsistentVariadicWithSameTypes() { - // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency - List args = - List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); - - SimpleExtension.VariadicBehavior variadic = - ImmutableSimpleExtension.VariadicBehavior.builder() - .min(1) - .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) - .build(); - - // All variadic arguments are i64 - should pass - assertTrue( - testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic), - "Consistent variadic with same types should match"); - - // All variadic arguments are i64 (with different nullability) - should pass - assertTrue( - testInputTypesMatch(List.of(R.I64, R.I64, N.I64, R.I64), args, variadic), - "Consistent variadic with same types but different nullability should match"); - } - - @Test - void testConsistentVariadicWithDifferentTypes() { - // Function: test_func(i64, any...) with CONSISTENT parameterConsistency - List args = - List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); - - SimpleExtension.VariadicBehavior variadic = - ImmutableSimpleExtension.VariadicBehavior.builder() - .min(1) - .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) - .build(); - - // Variadic arguments have different types - should fail - assertFalse( - testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), - "Consistent variadic with different types should not match"); - - assertFalse( - testInputTypesMatch(List.of(R.I64, R.I32, R.I64), args, variadic), - "Consistent variadic with different types should not match"); - } - - @Test - void testInconsistentVariadicWithDifferentTypes() { - // Function: test_func(i64, any...) with INCONSISTENT parameterConsistency - // When INCONSISTENT, each variadic argument can be a different type, but they all need to - // match the parameterized type constraint (any in this case) - ParameterizedType anyType = - ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); - - List args = - List.of( - SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build(), - SimpleExtension.ValueArgument.builder().value(anyType).name("variadic_arg").build()); - - SimpleExtension.VariadicBehavior variadic = - ImmutableSimpleExtension.VariadicBehavior.builder() - .min(1) - .parameterConsistency( - SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT) - .build(); - - // Variadic arguments have different types - should pass with INCONSISTENT and wildcard type - assertTrue( - testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), - "Inconsistent variadic with different types should match when using wildcard type"); - - assertTrue( - testInputTypesMatch(List.of(R.I64, R.I32, R.FP64, R.STRING), args, variadic), - "Inconsistent variadic with different types should match when using wildcard type"); - } - - @Test - void testConsistentVariadicWithWildcardType() { - // Function: test_func(any, any...) with CONSISTENT parameterConsistency - // The variadic arguments should all have the same concrete type - ParameterizedType anyType = - ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); - - List args = - List.of(SimpleExtension.ValueArgument.builder().value(anyType).name("arg1").build()); - - SimpleExtension.VariadicBehavior variadic = - ImmutableSimpleExtension.VariadicBehavior.builder() - .min(1) - .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) - .build(); - - // All variadic arguments are i64 - should pass - assertTrue( - testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic), - "Consistent variadic with wildcard type and same concrete types should match"); - - // Variadic arguments have different types - should fail - assertFalse( - testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), - "Consistent variadic with wildcard type and different concrete types should not match"); - } - - @Test - void testConsistentVariadicWithMinGreaterThanOne() { - // Function: test_func(i64, i64...) with CONSISTENT and min=2 - List args = - List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); - - SimpleExtension.VariadicBehavior variadic = - ImmutableSimpleExtension.VariadicBehavior.builder() - .min(2) - .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) - .build(); - - // All variadic arguments are i64 - should pass - assertTrue( - testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.I64), args, variadic), - "Consistent variadic with min=2 and same types should match"); - - // Variadic arguments have different types - should fail - assertFalse( - testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.FP64), args, variadic), - "Consistent variadic with min=2 and different types should not match"); - } - - @Test - void testNoVariadicBehavior() { - // Function: test_func(i64) with no variadic behavior - List args = - List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); - - // No variadic behavior - should pass regardless of consistency - assertTrue( - testInputTypesMatch(List.of(R.I64), args, null), - "No variadic behavior should always match"); - } - - @Test - void testConsistentVariadicWithNullableTypes() { - // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency - List args = - List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); - - SimpleExtension.VariadicBehavior variadic = - ImmutableSimpleExtension.VariadicBehavior.builder() - .min(1) - .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) - .build(); - - // Mix of nullable and non-nullable i64 - should pass (nullability is ignored) - assertTrue( - testInputTypesMatch(List.of(R.I64, N.I64, R.I64, N.I64), args, variadic), - "Consistent variadic with same types but different nullability should match"); - } -} From 718053ae4ff31801ea9b797c55c92b8b4e231535 Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Thu, 4 Dec 2025 13:18:33 +0100 Subject: [PATCH 4/7] simplify to class --- .../AggregateFunctionInvocation.java | 47 +--------- .../io/substrait/expression/Expression.java | 94 +------------------ ...VariadicParameterConsistencyValidator.java | 74 +++++++++++++++ 3 files changed, 77 insertions(+), 138 deletions(-) create mode 100644 core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java diff --git a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java index 04edf3268..2bdfb00e3 100644 --- a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java +++ b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java @@ -32,52 +32,7 @@ public Type getType() { */ @Value.Check protected void check() { - SimpleExtension.Function func = declaration(); - java.util.Optional variadic = func.variadic(); - if (!variadic.isPresent()) { - return; - } - - SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); - if (variadicBehavior.parameterConsistency() - != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { - // INCONSISTENT allows different types, so validation passes - return; - } - - // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) - List argumentTypes = - arguments().stream() - .filter(arg -> arg instanceof Expression || arg instanceof Type) - .map( - arg -> { - if (arg instanceof Expression) { - return ((Expression) arg).getType(); - } else { - return (Type) arg; - } - }) - .collect(java.util.stream.Collectors.toList()); - - int fixedArgCount = func.args().size(); - if (argumentTypes.size() <= fixedArgCount) { - // No variadic arguments, validation passes - return; - } - - // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) - int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); - for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { - Type currentType = argumentTypes.get(i); - Type nextType = argumentTypes.get(i + 1); - // Normalize both types to nullable for comparison (ignoring nullability) - assert io.substrait.type.TypeCreator.asNullable(currentType) - .equals(io.substrait.type.TypeCreator.asNullable(nextType)) - : String.format( - "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " - + "Argument at index %d has type %s but argument at index %d has type %s", - i, currentType, i + 1, nextType); - } + VariadicParameterConsistencyValidator.validate(declaration(), arguments()); } public static ImmutableAggregateFunctionInvocation.Builder builder() { diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index ef8ff18b0..f34d94495 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -806,52 +806,7 @@ public Type getType() { */ @Value.Check protected void check() { - SimpleExtension.Function func = declaration(); - java.util.Optional variadic = func.variadic(); - if (!variadic.isPresent()) { - return; - } - - SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); - if (variadicBehavior.parameterConsistency() - != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { - // INCONSISTENT allows different types, so validation passes - return; - } - - // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) - List argumentTypes = - arguments().stream() - .filter(arg -> arg instanceof Expression || arg instanceof Type) - .map( - arg -> { - if (arg instanceof Expression) { - return ((Expression) arg).getType(); - } else { - return (Type) arg; - } - }) - .collect(java.util.stream.Collectors.toList()); - - int fixedArgCount = func.args().size(); - if (argumentTypes.size() <= fixedArgCount) { - // No variadic arguments, validation passes - return; - } - - // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) - int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); - for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { - Type currentType = argumentTypes.get(i); - Type nextType = argumentTypes.get(i + 1); - // Normalize both types to nullable for comparison (ignoring nullability) - assert io.substrait.type.TypeCreator.asNullable(currentType) - .equals(io.substrait.type.TypeCreator.asNullable(nextType)) - : String.format( - "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " - + "Argument at index %d has type %s but argument at index %d has type %s", - i, currentType, i + 1, nextType); - } + VariadicParameterConsistencyValidator.validate(declaration(), arguments()); } public static ImmutableExpression.ScalarFunctionInvocation.Builder builder() { @@ -902,52 +857,7 @@ public Type getType() { */ @Value.Check protected void check() { - SimpleExtension.Function func = declaration(); - java.util.Optional variadic = func.variadic(); - if (!variadic.isPresent()) { - return; - } - - SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); - if (variadicBehavior.parameterConsistency() - != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { - // INCONSISTENT allows different types, so validation passes - return; - } - - // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) - List argumentTypes = - arguments().stream() - .filter(arg -> arg instanceof Expression || arg instanceof Type) - .map( - arg -> { - if (arg instanceof Expression) { - return ((Expression) arg).getType(); - } else { - return (Type) arg; - } - }) - .collect(java.util.stream.Collectors.toList()); - - int fixedArgCount = func.args().size(); - if (argumentTypes.size() <= fixedArgCount) { - // No variadic arguments, validation passes - return; - } - - // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) - int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); - for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { - Type currentType = argumentTypes.get(i); - Type nextType = argumentTypes.get(i + 1); - // Normalize both types to nullable for comparison (ignoring nullability) - assert io.substrait.type.TypeCreator.asNullable(currentType) - .equals(io.substrait.type.TypeCreator.asNullable(nextType)) - : String.format( - "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " - + "Argument at index %d has type %s but argument at index %d has type %s", - i, currentType, i + 1, nextType); - } + VariadicParameterConsistencyValidator.validate(declaration(), arguments()); } public static ImmutableExpression.WindowFunctionInvocation.Builder builder() { diff --git a/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java new file mode 100644 index 000000000..626e7460d --- /dev/null +++ b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java @@ -0,0 +1,74 @@ +package io.substrait.expression; + +import io.substrait.extension.SimpleExtension; +import io.substrait.type.Type; +import java.util.List; + +/** + * Helper class for validating variadic parameter consistency in function invocations. Validates that + * when parameterConsistency is CONSISTENT, all variadic arguments have the same type (ignoring + * nullability). + */ +public class VariadicParameterConsistencyValidator { + + /** + * Validates that variadic arguments satisfy the parameter consistency requirement. When + * CONSISTENT, all variadic arguments must have the same type (ignoring nullability). When + * INCONSISTENT, arguments can have different types. + * + * @param func the function declaration + * @param arguments the function arguments to validate + * @throws AssertionError if validation fails + */ + public static void validate( + SimpleExtension.Function func, List arguments) { + java.util.Optional variadic = func.variadic(); + if (!variadic.isPresent()) { + return; + } + + SimpleExtension.VariadicBehavior variadicBehavior = variadic.get(); + if (variadicBehavior.parameterConsistency() + != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { + // INCONSISTENT allows different types, so validation passes + return; + } + + // Extract types from arguments (only Expression and Type have types, EnumArg doesn't) + List argumentTypes = + arguments.stream() + .filter(arg -> arg instanceof Expression || arg instanceof Type) + .map( + arg -> { + if (arg instanceof Expression) { + return ((Expression) arg).getType(); + } else { + return (Type) arg; + } + }) + .collect(java.util.stream.Collectors.toList()); + + int fixedArgCount = func.args().size(); + if (argumentTypes.size() <= fixedArgCount) { + // No variadic arguments, validation passes + return; + } + + // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) + int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); + for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { + Type currentType = argumentTypes.get(i); + Type nextType = argumentTypes.get(i + 1); + // Normalize both types to nullable for comparison (ignoring nullability) + if (!io.substrait.type.TypeCreator.asNullable(currentType) + .equals(io.substrait.type.TypeCreator.asNullable(nextType))) { + throw new AssertionError( + String.format( + "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " + + "Argument at index %d has type %s but argument at index %d has type %s", + i, currentType, i + 1, nextType)); + } + } + } +} + From 1eeed901976149d718f09a54ea0a1a1b38c825e2 Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Fri, 5 Dec 2025 10:24:00 +0100 Subject: [PATCH 5/7] nit: comments --- .../expression/VariadicParameterConsistencyTest.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java index d936f1b44..0f6e03823 100644 --- a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java +++ b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java @@ -55,7 +55,6 @@ void testConsistentVariadicWithSameTypes() { .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) .build(); - // All variadic arguments are i64 - should pass assertDoesNotThrow( () -> createScalarFunctionInvocation( @@ -67,7 +66,6 @@ void testConsistentVariadicWithSameTypes() { Expression.I64Literal.builder().value(3).build())), "Consistent variadic with same types should pass"); - // All variadic arguments are i64 (with different nullability) - should pass assertDoesNotThrow( () -> createScalarFunctionInvocation( @@ -93,7 +91,6 @@ void testConsistentVariadicWithDifferentTypes() { .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) .build(); - // Variadic arguments have different types - should fail assertThrows( AssertionError.class, () -> @@ -139,7 +136,6 @@ void testInconsistentVariadicWithDifferentTypes() { SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT) .build(); - // Variadic arguments have different types - should pass with INCONSISTENT assertDoesNotThrow( () -> createScalarFunctionInvocation( @@ -180,7 +176,6 @@ void testConsistentVariadicWithWildcardType() { .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) .build(); - // All variadic arguments are i64 - should pass assertDoesNotThrow( () -> createScalarFunctionInvocation( @@ -192,7 +187,6 @@ void testConsistentVariadicWithWildcardType() { Expression.I64Literal.builder().value(3).build())), "Consistent variadic with wildcard type and same concrete types should pass"); - // Variadic arguments have different types - should fail assertThrows( AssertionError.class, () -> @@ -218,7 +212,6 @@ void testConsistentVariadicWithMinGreaterThanOne() { .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) .build(); - // All variadic arguments are i64 - should pass assertDoesNotThrow( () -> createScalarFunctionInvocation( @@ -231,7 +224,6 @@ void testConsistentVariadicWithMinGreaterThanOne() { Expression.I64Literal.builder().value(4).build())), "Consistent variadic with min=2 and same types should pass"); - // Variadic arguments have different types - should fail assertThrows( AssertionError.class, () -> @@ -252,7 +244,6 @@ void testNoVariadicBehavior() { List args = List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); - // No variadic behavior - should pass regardless of consistency assertDoesNotThrow( () -> createScalarFunctionInvocation( @@ -272,7 +263,6 @@ void testConsistentVariadicWithNullableTypes() { .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) .build(); - // Mix of nullable and non-nullable i64 - should pass (nullability is ignored) assertDoesNotThrow( () -> createScalarFunctionInvocation( From 0d8c1b6a82f2d7dc3cfb95352a02033311e4613f Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Fri, 5 Dec 2025 11:10:27 +0100 Subject: [PATCH 6/7] type comparison simplification --- .../VariadicParameterConsistencyValidator.java | 14 +++++++------- core/src/main/java/io/substrait/type/Type.java | 10 ++++++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java index 626e7460d..5dcb54d50 100644 --- a/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java +++ b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java @@ -55,18 +55,18 @@ public static void validate( } // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) - int firstVariadicArgIdx = Math.max(variadicBehavior.getMin() - 1, 0); - for (int i = firstVariadicArgIdx; i < argumentTypes.size() - 1; i++) { + // Compare all variadic arguments to the first one for more informative error messages + // Variadic arguments start after the fixed arguments + int firstVariadicArgIdx = fixedArgCount + Math.max(variadicBehavior.getMin() - 1, 0); + Type firstVariadicType = argumentTypes.get(firstVariadicArgIdx); + for (int i = firstVariadicArgIdx + 1; i < argumentTypes.size(); i++) { Type currentType = argumentTypes.get(i); - Type nextType = argumentTypes.get(i + 1); - // Normalize both types to nullable for comparison (ignoring nullability) - if (!io.substrait.type.TypeCreator.asNullable(currentType) - .equals(io.substrait.type.TypeCreator.asNullable(nextType))) { + if (!firstVariadicType.equalsIgnoringNullability(currentType)) { throw new AssertionError( String.format( "Variadic arguments must have consistent types when parameterConsistency is CONSISTENT. " + "Argument at index %d has type %s but argument at index %d has type %s", - i, currentType, i + 1, nextType)); + firstVariadicArgIdx, firstVariadicType, i, currentType)); } } } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..15606f5a3 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -25,6 +25,16 @@ default R accept( return fnArgVisitor.visitType(fnDef, argIdx, this, context); } + /** + * Compares this type with another type, ignoring nullability differences. + * + * @param other the type to compare with + * @return true if the types are equal when both are treated as nullable + */ + default boolean equalsIgnoringNullability(Type other) { + return TypeCreator.asNullable(this).equals(TypeCreator.asNullable(other)); + } + @Value.Immutable abstract class Bool implements Type { public static ImmutableType.Bool.Builder builder() { From e0ca93b35c6985f2939acb622d24629f84c557bd Mon Sep 17 00:00:00 2001 From: Adam Alani Date: Tue, 9 Dec 2025 10:59:29 +0100 Subject: [PATCH 7/7] fix: lint and tests --- ...VariadicParameterConsistencyValidator.java | 47 ++++++++++++++----- .../VariadicParameterConsistencyTest.java | 2 - .../extension/TypeExtensionTest.java | 1 - .../extension/VariadicBehaviorTest.java | 1 - 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java index 5dcb54d50..8619d9274 100644 --- a/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java +++ b/core/src/main/java/io/substrait/expression/VariadicParameterConsistencyValidator.java @@ -3,10 +3,12 @@ import io.substrait.extension.SimpleExtension; import io.substrait.type.Type; import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; /** - * Helper class for validating variadic parameter consistency in function invocations. Validates that - * when parameterConsistency is CONSISTENT, all variadic arguments have the same type (ignoring + * Helper class for validating variadic parameter consistency in function invocations. Validates + * that when parameterConsistency is CONSISTENT, all variadic arguments have the same type (ignoring * nullability). */ public class VariadicParameterConsistencyValidator { @@ -20,9 +22,8 @@ public class VariadicParameterConsistencyValidator { * @param arguments the function arguments to validate * @throws AssertionError if validation fails */ - public static void validate( - SimpleExtension.Function func, List arguments) { - java.util.Optional variadic = func.variadic(); + public static void validate(SimpleExtension.Function func, List arguments) { + Optional variadic = func.variadic(); if (!variadic.isPresent()) { return; } @@ -31,6 +32,17 @@ public static void validate( if (variadicBehavior.parameterConsistency() != SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) { // INCONSISTENT allows different types, so validation passes + // TODO: Even when parameterConsistency is INCONSISTENT, there can be implicit constraints + // across variadic parameters due to type parameters. For example, consider a function with: + // args: [value: "decimal", variadic: {min: 1, parameterConsistency: INCONSISTENT}] + // return: "decimal<38,S>" + // In this case, while the precision P can vary across variadic arguments, the scale S must + // be consistent across all variadic arguments (since it's used in the return type). The + // current implementation doesn't validate these type parameter constraints. According to + // the spec: "Each argument can be any possible concrete type afforded by the bounds of any + // parameter defined in the arguments specification." This means we need to check that type + // parameters that appear in the return type (or are otherwise constrained) are consistent + // across variadic arguments, even when parameterConsistency is INCONSISTENT. return; } @@ -46,18 +58,32 @@ public static void validate( return (Type) arg; } }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); - int fixedArgCount = func.args().size(); - if (argumentTypes.size() <= fixedArgCount) { + // Count how many Expression/Type arguments are in the fixed arguments (before variadic) + // Note: func.args() includes all argument types (Expression, Type, EnumArg), but we only + // care about Expression/Type arguments for type consistency checking + int fixedTypeArgCount = 0; + for (int i = 0; i < func.args().size() && i < arguments.size(); i++) { + FunctionArg arg = arguments.get(i); + if (arg instanceof Expression || arg instanceof Type) { + fixedTypeArgCount++; + } + } + + if (argumentTypes.size() <= fixedTypeArgCount) { // No variadic arguments, validation passes return; } // For CONSISTENT, all variadic arguments must have the same type (ignoring nullability) // Compare all variadic arguments to the first one for more informative error messages - // Variadic arguments start after the fixed arguments - int firstVariadicArgIdx = fixedArgCount + Math.max(variadicBehavior.getMin() - 1, 0); + // Variadic arguments start immediately after the fixed arguments + int firstVariadicArgIdx = fixedTypeArgCount; + if (firstVariadicArgIdx >= argumentTypes.size()) { + // Not enough variadic arguments provided, validation passes + return; + } Type firstVariadicType = argumentTypes.get(firstVariadicArgIdx); for (int i = firstVariadicArgIdx + 1; i < argumentTypes.size(); i++) { Type currentType = argumentTypes.get(i); @@ -71,4 +97,3 @@ public static void validate( } } } - diff --git a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java index 0f6e03823..61d74f395 100644 --- a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java +++ b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java @@ -14,7 +14,6 @@ class VariadicParameterConsistencyTest { private static final TypeCreator R = TypeCreator.of(false); - private static final TypeCreator N = TypeCreator.of(true); /** * Helper method to create a ScalarFunctionInvocation and test if it validates correctly. The @@ -276,4 +275,3 @@ void testConsistentVariadicWithNullableTypes() { "Consistent variadic with same types but different nullability should pass"); } } - diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index 9d270bff9..abb4008d5 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -97,5 +97,4 @@ void roundtripNumberedAnyTypes() { Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } - } diff --git a/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java b/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java index bc4a7f202..e4e9d2a7a 100644 --- a/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java +++ b/core/src/test/java/io/substrait/extension/VariadicBehaviorTest.java @@ -30,4 +30,3 @@ void testParameterConsistencyLoading() { collection.scalarFunctions().get(0).variadic().get().parameterConsistency()); } } -