Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ enum ParameterConsistency {
INCONSISTENT
}

@Value.Default
default ParameterConsistency parameterConsistency() {
return ParameterConsistency.CONSISTENT;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a reasonble default, because I think it's what most people expect in practice and it's also the first value in the enumeration.

We can formalize this in the spec more concretely.

}
Expand Down
23 changes: 23 additions & 0 deletions core/src/test/java/io/substrait/extension/TypeExtensionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,27 @@ void roundtripNumberedAnyTypes() {
Plan planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
}

@Test
public void testParameterConsistencyLoading() {
String yamlContent =
"urn: extension:test:example\n"
+ "scalar_functions:\n"
+ " - name: test_func\n"
+ " impls:\n"
+ " - args:\n"
+ " - name: arg1\n"
+ " value: string\n"
+ " variadic:\n"
+ " min: 1\n"
+ " parameterConsistency: CONSISTENT\n"
+ " return: string\n";

SimpleExtension.ExtensionCollection collection =
SimpleExtension.load("test://example", yamlContent);

assertEquals(
SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT,
collection.scalarFunctions().get(0).variadic().get().parameterConsistency());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.substrait.isthmus.expression.FunctionMappings.TypeBasedResolver;
import io.substrait.type.Type;
import io.substrait.util.Util;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -37,12 +38,14 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -244,7 +247,8 @@ private Optional<F> signatureMatch(List<Type> inputTypes, Type outputType) {
// Make sure that arguments & return are within bounds and match the types
if (function.returnType() instanceof ParameterizedType
&& isMatch(outputType, (ParameterizedType) function.returnType())
&& inputTypesMatchDefinedArguments(inputTypes, args)) {
&& inputTypesMatchDefinedArguments(
inputTypes, args, function.variadic().orElse(null))) {
return Optional.of(function);
}
}
Expand All @@ -256,17 +260,21 @@ && inputTypesMatchDefinedArguments(inputTypes, args)) {
* Checks to see if the given input types satisfy the function arguments given. Checks that
*
* <ul>
* <li>Variadic arguments all have the same input type
* <li>Variadic arguments all have the same input type (when parameterConsistency is
* CONSISTENT)
* <li>Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input
* type
* </ul>
*
* @param inputTypes input types to check against arguments
* @param args expected arguments as defined in a {@link SimpleExtension.Function}
* @param variadic the variadic behavior to check for consistency, or null if not variadic
* @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
*/
private boolean inputTypesMatchDefinedArguments(
List<Type> inputTypes, List<SimpleExtension.Argument> args) {
boolean inputTypesMatchDefinedArguments(
List<Type> inputTypes,
List<SimpleExtension.Argument> args,
SimpleExtension.@Nullable VariadicBehavior variadic) {

Map<String, Set<Type>> wildcardToType = new HashMap<>();
for (int i = 0; i < inputTypes.size(); i++) {
Expand All @@ -292,9 +300,39 @@ private boolean inputTypesMatchDefinedArguments(
}

// If all the types match, check if the wildcard types are compatible.
// TODO: Determine if non-enumerated wildcard types (i.e. `any` as opposed to `any1`) need to
// have the same type.
return wildcardToType.values().stream().allMatch(s -> s.size() == 1);
// When parameterConsistency is INCONSISTENT, wildcard types can differ.
// When parameterConsistency is CONSISTENT (or not variadic), wildcard types must be the same.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand it, the parameterConsistency behaviour only applies to the variadic arguments. It's not related to how we typecheck the enumerated wildcard types (i.e any1, any2, etc). From the (limited) documentation on it:

When the last argument of a function is variadic and declares a type parameter e.g. fn(A, B, C...), the C parameter can be marked as either consistent or inconsistent. If marked as consistent, the function can only be bound to arguments where all the C types are the same concrete type. If marked as inconsistent, each unique C can be bound to a different type within the constraints of what T allows.

CONSISTENT means that the types of all the variadic arguments have to be the same. INCONSISTENT means... I'm not sure tbh because in

each unique C can be bound to a different type within the constraints of what T allows.

I don't know what T is. This is something I can follow up with the community about.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I realize that this is attempting to handle the second part of the ticket I linked (#622).

However, I don't believe this is the right place to handle it. Considering that parameterConsistency has meaning regardless of calcite, those checks should really be inside of core/ IMO.

boolean allowInconsistentWildcards =
variadic != null
&& variadic.parameterConsistency()
== SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT;
if (!allowInconsistentWildcards
&& !wildcardToType.values().stream().allMatch(s -> s.size() == 1)) {
return false;
}

// Validate variadic argument consistency
if (variadic != null
&& variadic.parameterConsistency()
== SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT
&& inputTypes.size() > args.size()) {
// Check that all variadic arguments have the same type (ignoring nullability)
int firstVariadicArgIdx = Math.max(variadic.getMin() - 1, 0);
// Compare consecutive arguments starting from firstVariadicArgIdx
for (int i = firstVariadicArgIdx; i < inputTypes.size() - 1; i++) {
Type currentType = inputTypes.get(i);
Type nextType = inputTypes.get(i + 1);
// Compare types ignoring nullability - check both directions to ensure types match
boolean typesMatch =
currentType.accept(new IgnoreNullableAndParameters(nextType))
|| nextType.accept(new IgnoreNullableAndParameters(currentType));
if (!typesMatch) {
return false;
}
}
}

return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
package io.substrait.isthmus.expression;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.substrait.extension.ImmutableSimpleExtension;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ParameterizedType;
import io.substrait.isthmus.PlanTestBase;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import org.junit.jupiter.api.Test;

/** Tests for variadic parameter consistency validation in FunctionConverter. */
class VariadicParameterConsistencyTest extends PlanTestBase {

private static final TypeCreator R = TypeCreator.of(false);
private static final TypeCreator N = TypeCreator.of(true);

/**
* Helper method to test if input types match defined arguments with variadic behavior. Creates a
* minimal FunctionConverter and FunctionFinder to access the package-private method.
*/
private boolean testInputTypesMatch(
List<Type> inputTypes,
List<SimpleExtension.Argument> args,
SimpleExtension.VariadicBehavior variadic) {
// Create a minimal test function to use with FunctionConverter
SimpleExtension.ScalarFunctionVariant testFunction =
ImmutableSimpleExtension.ScalarFunctionVariant.builder()
.urn("extension:test:variadic")
.name("test_func")
.args(args)
.variadic(
variadic != null ? java.util.Optional.of(variadic) : java.util.Optional.empty())
.returnType(R.I64)
.options(java.util.Collections.emptyMap())
.build();

// Create a ScalarFunctionConverter with our test function
ScalarFunctionConverter converter =
new ScalarFunctionConverter(
List.of(testFunction),
java.util.Collections.emptyList(),
typeFactory,
TypeConverter.DEFAULT);

// Create a test helper that can access the package-private method
TestHelper testHelper = new TestHelper(converter, testFunction);
return testHelper.testInputTypesMatch(inputTypes, args, variadic);
}

/**
* Test helper class that extends ScalarFunctionConverter to access package-private
* inputTypesMatchDefinedArguments method.
*/
private static class TestHelper extends ScalarFunctionConverter {
private final SimpleExtension.ScalarFunctionVariant testFunction;

TestHelper(ScalarFunctionConverter parent, SimpleExtension.ScalarFunctionVariant testFunction) {
super(
java.util.Collections.emptyList(),
java.util.Collections.emptyList(),
parent.typeFactory,
parent.typeConverter);
this.testFunction = testFunction;
}

boolean testInputTypesMatch(
List<Type> inputTypes,
List<SimpleExtension.Argument> args,
SimpleExtension.VariadicBehavior variadic) {
// Create a minimal FunctionFinder to access the package-private method
// We need a SqlOperator, so create a dummy one
org.apache.calcite.sql.SqlFunction dummyOperator =
new org.apache.calcite.sql.SqlFunction(
"test",
org.apache.calcite.sql.SqlKind.OTHER_FUNCTION,
org.apache.calcite.sql.type.ReturnTypes.explicit(
org.apache.calcite.sql.type.SqlTypeName.BIGINT),
null,
null,
org.apache.calcite.sql.SqlFunctionCategory.USER_DEFINED_FUNCTION);

// Create a FunctionFinder with the test function so it can compute argRange properly
// FunctionFinder is a protected inner class, so we can access it from a subclass
FunctionFinder finder = new FunctionFinder("test_func", dummyOperator, List.of(testFunction));

// Access the package-private method
return finder.inputTypesMatchDefinedArguments(inputTypes, args, variadic);
}
}

@Test
void testConsistentVariadicWithSameTypes() {
// Function: test_func(i64, i64...) with CONSISTENT parameterConsistency
List<SimpleExtension.Argument> args =
List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build());

SimpleExtension.VariadicBehavior variadic =
ImmutableSimpleExtension.VariadicBehavior.builder()
.min(1)
.parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT)
.build();

// All variadic arguments are i64 - should pass
assertTrue(
testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic),
"Consistent variadic with same types should match");

// All variadic arguments are i64 (with different nullability) - should pass
assertTrue(
testInputTypesMatch(List.of(R.I64, R.I64, N.I64, R.I64), args, variadic),
"Consistent variadic with same types but different nullability should match");
}

@Test
void testConsistentVariadicWithDifferentTypes() {
// Function: test_func(i64, any...) with CONSISTENT parameterConsistency
List<SimpleExtension.Argument> args =
List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build());

SimpleExtension.VariadicBehavior variadic =
ImmutableSimpleExtension.VariadicBehavior.builder()
.min(1)
.parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT)
.build();

// Variadic arguments have different types - should fail
assertFalse(
testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic),
"Consistent variadic with different types should not match");

assertFalse(
testInputTypesMatch(List.of(R.I64, R.I32, R.I64), args, variadic),
"Consistent variadic with different types should not match");
}

@Test
void testInconsistentVariadicWithDifferentTypes() {
// Function: test_func(i64, any...) with INCONSISTENT parameterConsistency
// When INCONSISTENT, each variadic argument can be a different type, but they all need to
// match the parameterized type constraint (any in this case)
ParameterizedType anyType =
ParameterizedType.StringLiteral.builder().value("any").nullable(false).build();

List<SimpleExtension.Argument> args =
List.of(
SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build(),
SimpleExtension.ValueArgument.builder().value(anyType).name("variadic_arg").build());

SimpleExtension.VariadicBehavior variadic =
ImmutableSimpleExtension.VariadicBehavior.builder()
.min(1)
.parameterConsistency(
SimpleExtension.VariadicBehavior.ParameterConsistency.INCONSISTENT)
.build();

// Variadic arguments have different types - should pass with INCONSISTENT and wildcard type
assertTrue(
testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic),
"Inconsistent variadic with different types should match when using wildcard type");

assertTrue(
testInputTypesMatch(List.of(R.I64, R.I32, R.FP64, R.STRING), args, variadic),
"Inconsistent variadic with different types should match when using wildcard type");
}

@Test
void testConsistentVariadicWithWildcardType() {
// Function: test_func(any, any...) with CONSISTENT parameterConsistency
// The variadic arguments should all have the same concrete type
ParameterizedType anyType =
ParameterizedType.StringLiteral.builder().value("any").nullable(false).build();

List<SimpleExtension.Argument> args =
List.of(SimpleExtension.ValueArgument.builder().value(anyType).name("arg1").build());

SimpleExtension.VariadicBehavior variadic =
ImmutableSimpleExtension.VariadicBehavior.builder()
.min(1)
.parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT)
.build();

// All variadic arguments are i64 - should pass
assertTrue(
testInputTypesMatch(List.of(R.I64, R.I64, R.I64), args, variadic),
"Consistent variadic with wildcard type and same concrete types should match");

// Variadic arguments have different types - should fail
assertFalse(
testInputTypesMatch(List.of(R.I64, R.I64, R.FP64), args, variadic),
"Consistent variadic with wildcard type and different concrete types should not match");
}

@Test
void testConsistentVariadicWithMinGreaterThanOne() {
// Function: test_func(i64, i64...) with CONSISTENT and min=2
List<SimpleExtension.Argument> args =
List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build());

SimpleExtension.VariadicBehavior variadic =
ImmutableSimpleExtension.VariadicBehavior.builder()
.min(2)
.parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT)
.build();

// All variadic arguments are i64 - should pass
assertTrue(
testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.I64), args, variadic),
"Consistent variadic with min=2 and same types should match");

// Variadic arguments have different types - should fail
assertFalse(
testInputTypesMatch(List.of(R.I64, R.I64, R.I64, R.FP64), args, variadic),
"Consistent variadic with min=2 and different types should not match");
}

@Test
void testNoVariadicBehavior() {
// Function: test_func(i64) with no variadic behavior
List<SimpleExtension.Argument> args =
List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build());

// No variadic behavior - should pass regardless of consistency
assertTrue(
testInputTypesMatch(List.of(R.I64), args, null),
"No variadic behavior should always match");
}

@Test
void testConsistentVariadicWithNullableTypes() {
// Function: test_func(i64, i64...) with CONSISTENT parameterConsistency
List<SimpleExtension.Argument> args =
List.of(SimpleExtension.ValueArgument.builder().value(R.I64).name("arg1").build());

SimpleExtension.VariadicBehavior variadic =
ImmutableSimpleExtension.VariadicBehavior.builder()
.min(1)
.parameterConsistency(SimpleExtension.VariadicBehavior.ParameterConsistency.CONSISTENT)
.build();

// Mix of nullable and non-nullable i64 - should pass (nullability is ignored)
assertTrue(
testInputTypesMatch(List.of(R.I64, N.I64, R.I64, N.I64), args, variadic),
"Consistent variadic with same types but different nullability should match");
}
}
Loading