-
Notifications
You must be signed in to change notification settings - Fork 93
feat: handle parameterConsistency option in YAML extensions #624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
e54b4f2
033f093
f8abdcd
718053a
1eeed90
0d8c1b6
e0ca93b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| import io.substrait.isthmus.expression.FunctionMappings.TypeBasedResolver; | ||
| import io.substrait.type.Type; | ||
| import io.substrait.util.Util; | ||
|
|
||
| import java.util.ArrayList; | ||
| import java.util.Collection; | ||
| import java.util.Collections; | ||
|
|
@@ -37,12 +38,14 @@ | |
| import java.util.stream.Collectors; | ||
| import java.util.stream.IntStream; | ||
| import java.util.stream.Stream; | ||
|
|
||
| import org.apache.calcite.rel.type.RelDataType; | ||
| import org.apache.calcite.rel.type.RelDataTypeFactory; | ||
| import org.apache.calcite.rex.RexBuilder; | ||
| import org.apache.calcite.rex.RexLiteral; | ||
| import org.apache.calcite.rex.RexNode; | ||
| import org.apache.calcite.sql.SqlOperator; | ||
| import org.jspecify.annotations.Nullable; | ||
| import org.slf4j.Logger; | ||
| import org.slf4j.LoggerFactory; | ||
|
|
||
|
|
@@ -244,7 +247,8 @@ private Optional<F> signatureMatch(List<Type> inputTypes, Type outputType) { | |
| // Make sure that arguments & return are within bounds and match the types | ||
| if (function.returnType() instanceof ParameterizedType | ||
| && isMatch(outputType, (ParameterizedType) function.returnType()) | ||
| && inputTypesMatchDefinedArguments(inputTypes, args)) { | ||
| && inputTypesMatchDefinedArguments( | ||
| inputTypes, args, function.variadic().orElse(null))) { | ||
| return Optional.of(function); | ||
| } | ||
| } | ||
|
|
@@ -256,17 +260,21 @@ && inputTypesMatchDefinedArguments(inputTypes, args)) { | |
| * Checks to see if the given input types satisfy the function arguments given. Checks that | ||
| * | ||
| * <ul> | ||
| * <li>Variadic arguments all have the same input type | ||
| * <li>Variadic arguments all have the same input type (when parameterConsistency is | ||
| * CONSISTENT) | ||
| * <li>Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input | ||
| * type | ||
| * </ul> | ||
| * | ||
| * @param inputTypes input types to check against arguments | ||
| * @param args expected arguments as defined in a {@link SimpleExtension.Function} | ||
| * @param variadic the variadic behavior to check for consistency, or null if not variadic | ||
| * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise | ||
| */ | ||
| private boolean inputTypesMatchDefinedArguments( | ||
| List<Type> inputTypes, List<SimpleExtension.Argument> args) { | ||
| boolean inputTypesMatchDefinedArguments( | ||
| List<Type> inputTypes, | ||
| List<SimpleExtension.Argument> args, | ||
| SimpleExtension.@Nullable VariadicBehavior variadic) { | ||
|
|
||
| Map<String, Set<Type>> wildcardToType = new HashMap<>(); | ||
| for (int i = 0; i < inputTypes.size(); i++) { | ||
|
|
@@ -292,9 +300,39 @@ private boolean inputTypesMatchDefinedArguments( | |
| } | ||
|
|
||
| // If all the types match, check if the wildcard types are compatible. | ||
| // TODO: Determine if non-enumerated wildcard types (i.e. `any` as opposed to `any1`) need to | ||
| // have the same type. | ||
| return wildcardToType.values().stream().allMatch(s -> s.size() == 1); | ||
| // When parameterConsistency is INCONSISTENT, wildcard types can differ. | ||
| // When parameterConsistency is CONSISTENT (or not variadic), wildcard types must be the same. | ||
|
||
| boolean allowInconsistentWildcards = | ||
| variadic != null | ||
| && variadic.parameterConsistency() | ||
| == SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT; | ||
| if (!allowInconsistentWildcards | ||
| && !wildcardToType.values().stream().allMatch(s -> s.size() == 1)) { | ||
| return false; | ||
| } | ||
|
|
||
| // Validate variadic argument consistency | ||
| if (variadic != null | ||
| && variadic.parameterConsistency() | ||
| == SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT | ||
| && inputTypes.size() > args.size()) { | ||
| // Check that all variadic arguments have the same type (ignoring nullability) | ||
| int firstVariadicArgIdx = Math.max(variadic.getMin() - 1, 0); | ||
| // Compare consecutive arguments starting from firstVariadicArgIdx | ||
| for (int i = firstVariadicArgIdx; i < inputTypes.size() - 1; i++) { | ||
| Type currentType = inputTypes.get(i); | ||
| Type nextType = inputTypes.get(i + 1); | ||
| // Compare types ignoring nullability - check both directions to ensure types match | ||
| boolean typesMatch = | ||
| currentType.accept(new IgnoreNullableAndParameters(nextType)) | ||
| || nextType.accept(new IgnoreNullableAndParameters(currentType)); | ||
| if (!typesMatch) { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,250 @@ | ||
| package io.substrait.isthmus.expression; | ||
|
|
||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||
|
|
||
| import io.substrait.extension.ImmutableSimpleExtension; | ||
| import io.substrait.extension.SimpleExtension; | ||
| import io.substrait.function.ParameterizedType; | ||
| import io.substrait.isthmus.PlanTestBase; | ||
| import io.substrait.isthmus.TypeConverter; | ||
| import io.substrait.type.Type; | ||
| import io.substrait.type.TypeCreator; | ||
| import java.util.List; | ||
| import org.junit.jupiter.api.Test; | ||
|
|
||
| /** Tests for variadic parameter consistency validation in FunctionConverter. */ | ||
| class VariadicParameterConsistencyTest extends PlanTestBase { | ||
Adam-Alani marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| private static final TypeCreator R = TypeCreator.of(false); | ||
| private static final TypeCreator N = TypeCreator.of(true); | ||
|
|
||
| /** | ||
| * Helper method to test if input types match defined arguments with variadic behavior. Creates a | ||
| * minimal FunctionConverter and FunctionFinder to access the package-private method. | ||
| */ | ||
| private boolean testInputTypesMatch( | ||
| List<Type> inputTypes, | ||
| List<SimpleExtension.Argument> args, | ||
| SimpleExtension.VariadicBehavior variadic) { | ||
| // Create a minimal test function to use with FunctionConverter | ||
| SimpleExtension.ScalarFunctionVariant testFunction = | ||
| ImmutableSimpleExtension.ScalarFunctionVariant.builder() | ||
| .urn("extension:test:variadic") | ||
| .name("test_func") | ||
| .args(args) | ||
| .variadic( | ||
| variadic != null ? java.util.Optional.of(variadic) : java.util.Optional.empty()) | ||
| .returnType(R.I64) | ||
| .options(java.util.Collections.emptyMap()) | ||
| .build(); | ||
|
|
||
| // Create a ScalarFunctionConverter with our test function | ||
| ScalarFunctionConverter converter = | ||
| new ScalarFunctionConverter( | ||
| List.of(testFunction), | ||
| java.util.Collections.emptyList(), | ||
| typeFactory, | ||
| TypeConverter.DEFAULT); | ||
|
|
||
| // Create a test helper that can access the package-private method | ||
| TestHelper testHelper = new TestHelper(converter, testFunction); | ||
| return testHelper.testInputTypesMatch(inputTypes, args, variadic); | ||
| } | ||
|
|
||
| /** | ||
| * Test helper class that extends ScalarFunctionConverter to access package-private | ||
| * inputTypesMatchDefinedArguments method. | ||
| */ | ||
| private static class TestHelper extends ScalarFunctionConverter { | ||
| private final SimpleExtension.ScalarFunctionVariant testFunction; | ||
|
|
||
| TestHelper(ScalarFunctionConverter parent, SimpleExtension.ScalarFunctionVariant testFunction) { | ||
| super( | ||
| java.util.Collections.emptyList(), | ||
| java.util.Collections.emptyList(), | ||
| parent.typeFactory, | ||
| parent.typeConverter); | ||
| this.testFunction = testFunction; | ||
| } | ||
|
|
||
| boolean testInputTypesMatch( | ||
| List<Type> inputTypes, | ||
| List<SimpleExtension.Argument> args, | ||
| SimpleExtension.VariadicBehavior variadic) { | ||
| // Create a minimal FunctionFinder to access the package-private method | ||
| // We need a SqlOperator, so create a dummy one | ||
| org.apache.calcite.sql.SqlFunction dummyOperator = | ||
| new org.apache.calcite.sql.SqlFunction( | ||
| "test", | ||
| org.apache.calcite.sql.SqlKind.OTHER_FUNCTION, | ||
| org.apache.calcite.sql.type.ReturnTypes.explicit( | ||
| org.apache.calcite.sql.type.SqlTypeName.BIGINT), | ||
| null, | ||
| null, | ||
| org.apache.calcite.sql.SqlFunctionCategory.USER_DEFINED_FUNCTION); | ||
|
|
||
| // Create a FunctionFinder with the test function so it can compute argRange properly | ||
| // FunctionFinder is a protected inner class, so we can access it from a subclass | ||
| FunctionFinder finder = new FunctionFinder("test_func", dummyOperator, List.of(testFunction)); | ||
|
|
||
| // Access the package-private method | ||
| return finder.inputTypesMatchDefinedArguments(inputTypes, args, variadic); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| void testConsistentVariadicWithSameTypes() { | ||
| // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency | ||
| List<SimpleExtension.Argument> args = | ||
| List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); | ||
|
|
||
| SimpleExtension.VariadicBehavior variadic = | ||
| ImmutableSimpleExtension.VariadicBehavior.builder() | ||
| .min(1) | ||
| .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) | ||
| .build(); | ||
|
|
||
| // All variadic arguments are i64 - should pass | ||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic), | ||
| "Consistent variadic with same types should match"); | ||
|
|
||
| // All variadic arguments are i64 (with different nullability) - should pass | ||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, N.I64, R.I64), args, variadic), | ||
| "Consistent variadic with same types but different nullability should match"); | ||
| } | ||
|
|
||
| @Test | ||
| void testConsistentVariadicWithDifferentTypes() { | ||
| // Function: test_func(i64, any...) with CONSISTENT parameterConsistency | ||
| List<SimpleExtension.Argument> args = | ||
| List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); | ||
|
|
||
| SimpleExtension.VariadicBehavior variadic = | ||
| ImmutableSimpleExtension.VariadicBehavior.builder() | ||
| .min(1) | ||
| .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) | ||
| .build(); | ||
|
|
||
| // Variadic arguments have different types - should fail | ||
| assertFalse( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), | ||
| "Consistent variadic with different types should not match"); | ||
|
|
||
| assertFalse( | ||
| testInputTypesMatch(List.of(R.I64, R.I32, R.I64), args, variadic), | ||
| "Consistent variadic with different types should not match"); | ||
| } | ||
|
|
||
| @Test | ||
| void testInconsistentVariadicWithDifferentTypes() { | ||
| // Function: test_func(i64, any...) with INCONSISTENT parameterConsistency | ||
| // When INCONSISTENT, each variadic argument can be a different type, but they all need to | ||
| // match the parameterized type constraint (any in this case) | ||
| ParameterizedType anyType = | ||
| ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); | ||
|
|
||
| List<SimpleExtension.Argument> args = | ||
| List.of( | ||
| SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build(), | ||
| SimpleExtension.ValueArgument.builder().value(anyType).name("variadic_arg").build()); | ||
|
|
||
| SimpleExtension.VariadicBehavior variadic = | ||
| ImmutableSimpleExtension.VariadicBehavior.builder() | ||
| .min(1) | ||
| .parameterConsistency( | ||
| SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT) | ||
| .build(); | ||
|
|
||
| // Variadic arguments have different types - should pass with INCONSISTENT and wildcard type | ||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), | ||
| "Inconsistent variadic with different types should match when using wildcard type"); | ||
|
|
||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64, R.I32, R.FP64, R.STRING), args, variadic), | ||
| "Inconsistent variadic with different types should match when using wildcard type"); | ||
| } | ||
|
|
||
| @Test | ||
| void testConsistentVariadicWithWildcardType() { | ||
| // Function: test_func(any, any...) with CONSISTENT parameterConsistency | ||
| // The variadic arguments should all have the same concrete type | ||
| ParameterizedType anyType = | ||
| ParameterizedType.StringLiteral.builder().value("any").nullable(false).build(); | ||
|
|
||
| List<SimpleExtension.Argument> args = | ||
| List.of(SimpleExtension.ValueArgument.builder().value(anyType).name("arg1").build()); | ||
|
|
||
| SimpleExtension.VariadicBehavior variadic = | ||
| ImmutableSimpleExtension.VariadicBehavior.builder() | ||
| .min(1) | ||
| .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) | ||
| .build(); | ||
|
|
||
| // All variadic arguments are i64 - should pass | ||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic), | ||
| "Consistent variadic with wildcard type and same concrete types should match"); | ||
|
|
||
| // Variadic arguments have different types - should fail | ||
| assertFalse( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic), | ||
| "Consistent variadic with wildcard type and different concrete types should not match"); | ||
| } | ||
|
|
||
| @Test | ||
| void testConsistentVariadicWithMinGreaterThanOne() { | ||
| // Function: test_func(i64, i64...) with CONSISTENT and min=2 | ||
| List<SimpleExtension.Argument> args = | ||
| List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); | ||
|
|
||
| SimpleExtension.VariadicBehavior variadic = | ||
| ImmutableSimpleExtension.VariadicBehavior.builder() | ||
| .min(2) | ||
| .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) | ||
| .build(); | ||
|
|
||
| // All variadic arguments are i64 - should pass | ||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.I64), args, variadic), | ||
| "Consistent variadic with min=2 and same types should match"); | ||
|
|
||
| // Variadic arguments have different types - should fail | ||
| assertFalse( | ||
| testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.FP64), args, variadic), | ||
| "Consistent variadic with min=2 and different types should not match"); | ||
| } | ||
|
|
||
| @Test | ||
| void testNoVariadicBehavior() { | ||
| // Function: test_func(i64) with no variadic behavior | ||
| List<SimpleExtension.Argument> args = | ||
| List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); | ||
|
|
||
| // No variadic behavior - should pass regardless of consistency | ||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64), args, null), | ||
| "No variadic behavior should always match"); | ||
| } | ||
|
|
||
| @Test | ||
| void testConsistentVariadicWithNullableTypes() { | ||
| // Function: test_func(i64, i64...) with CONSISTENT parameterConsistency | ||
| List<SimpleExtension.Argument> args = | ||
| List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build()); | ||
|
|
||
| SimpleExtension.VariadicBehavior variadic = | ||
| ImmutableSimpleExtension.VariadicBehavior.builder() | ||
| .min(1) | ||
| .parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT) | ||
| .build(); | ||
|
|
||
| // Mix of nullable and non-nullable i64 - should pass (nullability is ignored) | ||
| assertTrue( | ||
| testInputTypesMatch(List.of(R.I64, N.I64, R.I64, N.I64), args, variadic), | ||
| "Consistent variadic with same types but different nullability should match"); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is a reasonble default, because I think it's what most people expect in practice and it's also the first value in the enumeration.
We can formalize this in the spec more concretely.