Skip to content

Commit ec45173

Browse files
committed
chore(isthmus): check arg types in SimpleExtensionToSqlOperatorTest
1 parent 74fd2e7 commit ec45173

File tree

1 file changed

+140
-89
lines changed

1 file changed

+140
-89
lines changed

isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java

Lines changed: 140 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,22 @@
77
import io.substrait.extension.SimpleExtension;
88
import io.substrait.type.Type;
99
import io.substrait.type.TypeExpressionEvaluator;
10+
import java.util.Arrays;
1011
import java.util.Collections;
1112
import java.util.List;
12-
import java.util.function.Consumer;
13+
import java.util.Map;
14+
import java.util.function.Function;
15+
import java.util.regex.Matcher;
16+
import java.util.regex.Pattern;
17+
import java.util.stream.Collectors;
1318
import java.util.stream.Stream;
1419
import org.apache.calcite.rel.type.RelDataType;
1520
import org.apache.calcite.rel.type.RelDataTypeFactory;
1621
import org.apache.calcite.runtime.CalciteException;
1722
import org.apache.calcite.runtime.Resources;
1823
import org.apache.calcite.sql.SqlOperator;
1924
import org.apache.calcite.sql.SqlOperatorBinding;
25+
import org.apache.calcite.sql.type.SqlTypeName;
2026
import org.apache.calcite.sql.validate.SqlValidatorException;
2127
import org.junit.jupiter.params.ParameterizedTest;
2228
import org.junit.jupiter.params.provider.MethodSource;
@@ -25,40 +31,43 @@
2531
class SimpleExtensionToSqlOperatorTest {
2632

2733
private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml";
34+
private static final RelDataTypeFactory TYPE_FACTORY = SubstraitTypeSystem.TYPE_FACTORY;
35+
36+
private static final Map<String, SimpleExtension.Function> FUNCTION_DEFS;
37+
private static final Map<String, SqlOperator> OPERATORS;
38+
39+
static {
40+
final var extensions =
41+
SimpleExtension.load(
42+
CUSTOM_FUNCTION_PATH,
43+
SimpleExtensionToSqlOperatorTest.class.getResourceAsStream(CUSTOM_FUNCTION_PATH));
44+
45+
FUNCTION_DEFS =
46+
extensions.scalarFunctions().stream()
47+
.collect(
48+
Collectors.toUnmodifiableMap(f -> f.name().toLowerCase(), Function.identity()));
49+
50+
OPERATORS =
51+
SimpleExtensionToSqlOperator.from(extensions).stream()
52+
.collect(
53+
Collectors.toUnmodifiableMap(
54+
op -> op.getName().toLowerCase(), Function.identity()));
55+
}
2856

29-
private static final SimpleExtension.ExtensionCollection EXTENSIONS =
30-
SimpleExtension.load(
31-
CUSTOM_FUNCTION_PATH,
32-
SimpleExtensionToSqlOperatorTest.class.getResourceAsStream(CUSTOM_FUNCTION_PATH));
33-
34-
private static final List<SqlOperator> OPERATORS = SimpleExtensionToSqlOperator.from(EXTENSIONS);
35-
36-
/** Data carrier for test cases. */
57+
/** Test Specification. */
3758
record TestSpec(
3859
String name,
3960
int minArgs,
4061
int maxArgs,
4162
SimpleExtension.Nullability nullability,
42-
String expectedReturnType,
43-
Consumer<SqlOperator> customValidator) {
44-
45-
TestSpec(
46-
final String name,
47-
final int min,
48-
final int max,
49-
final SimpleExtension.Nullability nullability,
50-
final String returnType) {
51-
this(name, min, max, nullability, returnType, op -> {});
52-
}
53-
}
63+
List<String> expectedArgTypes) {}
5464

5565
@ParameterizedTest
5666
@MethodSource("provideTestSpecs")
5767
void testCustomUdfConversion(final TestSpec spec) {
58-
final SqlOperator operator = findOperator(spec.name);
59-
final SimpleExtension.Function funcDef = findFunctionDef(spec.name);
68+
final SqlOperator operator = getOperator(spec.name);
69+
final SimpleExtension.Function funcDef = getFunctionDef(spec.name);
6070

61-
// 1. Verify Argument Counts
6271
assertEquals(
6372
spec.minArgs,
6473
operator.getOperandCountRange().getMin(),
@@ -67,109 +76,151 @@ void testCustomUdfConversion(final TestSpec spec) {
6776
spec.maxArgs,
6877
operator.getOperandCountRange().getMax(),
6978
() -> spec.name + ": Incorrect max args");
70-
assertNotNull(operator.getOperandTypeChecker(), () -> spec.name + ": Type checker missing");
7179

72-
// 2. Verify Nullability (if specified)
7380
if (spec.nullability != null) {
7481
assertEquals(
7582
spec.nullability, funcDef.nullability(), () -> spec.name + ": Incorrect nullability");
7683
}
7784

78-
// 3. Verify Return Type
79-
verifyReturnType(operator, funcDef, spec.expectedReturnType);
85+
if (!spec.expectedArgTypes.isEmpty()) {
86+
verifyAllowedSignatures(operator, spec.expectedArgTypes);
87+
}
8088

81-
// 4. Custom Validation
82-
spec.customValidator.accept(operator);
89+
verifyReturnTypeConsistency(operator, funcDef);
8390
}
8491

8592
private static Stream<TestSpec> provideTestSpecs() {
8693
return Stream.of(
94+
new TestSpec("REGEXP_EXTRACT_CUSTOM", 2, 2, null, List.of("VARCHAR", "VARCHAR")),
95+
new TestSpec(
96+
"FORMAT_TEXT", 2, 2, SimpleExtension.Nullability.MIRROR, List.of("VARCHAR", "VARCHAR")),
97+
new TestSpec(
98+
"SYSTEM_PROPERTY_GET",
99+
1,
100+
1,
101+
SimpleExtension.Nullability.DECLARED_OUTPUT,
102+
List.of("VARCHAR")),
87103
new TestSpec(
88-
"REGEXP_EXTRACT_CUSTOM",
104+
"SAFE_DIVIDE_CUSTOM",
89105
2,
90106
2,
91-
null,
92-
"VARCHAR",
93-
op -> {
94-
final String sigs =
95-
op.getOperandTypeChecker().getAllowedSignatures(op, op.getName()).toLowerCase();
96-
// Calcite represents string families as <character>
97-
assertTrue(
98-
sigs.contains("varchar") || sigs.contains("string") || sigs.contains("character"),
99-
() -> "Signatures should contain string types. Actual: " + sigs);
100-
}),
101-
new TestSpec("FORMAT_TEXT", 2, 2, SimpleExtension.Nullability.MIRROR, "VARCHAR"),
102-
new TestSpec(
103-
"SYSTEM_PROPERTY_GET", 1, 1, SimpleExtension.Nullability.DECLARED_OUTPUT, "VARCHAR"),
104-
new TestSpec("SAFE_DIVIDE_CUSTOM", 2, 2, SimpleExtension.Nullability.DISCRETE, "REAL"));
107+
SimpleExtension.Nullability.DISCRETE,
108+
List.of("INTEGER", "INTEGER")));
105109
}
106110

107-
private void verifyReturnType(
108-
final SqlOperator operator,
109-
final SimpleExtension.Function funcDef,
110-
final String expectedTypeName) {
111-
assertNotNull(funcDef.returnType(), "Return type missing in YAML");
112-
assertNotNull(operator.getReturnTypeInference(), "SQL Operator missing return type inference");
111+
/**
112+
* Parses the operator's signature string and checks that the types match the expected list
113+
* index-by-index.
114+
*/
115+
private void verifyAllowedSignatures(
116+
final SqlOperator operator, final List<String> expectedArgTypes) {
117+
assertNotNull(operator.getOperandTypeChecker(), "Operand type checker is null");
113118

114-
// 1. Evaluate expected type from YAML
115-
final Type expectedType =
116-
TypeExpressionEvaluator.evaluateExpression(
117-
funcDef.returnType(), funcDef.args(), Collections.emptyList());
119+
// e.g., "SAFE_DIVIDE_CUSTOM(<NUMERIC>, <NUMERIC>)"
120+
final String signature =
121+
operator
122+
.getOperandTypeChecker()
123+
.getAllowedSignatures(operator, operator.getName())
124+
.toUpperCase();
125+
126+
// Regex to capture arguments inside parentheses: NAME(ARG1, ARG2)
127+
final Pattern pattern = Pattern.compile(".*?\\((.*)\\).*");
128+
final Matcher matcher = pattern.matcher(signature);
129+
130+
assertTrue(matcher.matches(), () -> "Signature format not recognized: " + signature);
118131

119-
// 2. Convert expected Substrait type to Calcite type
120-
final RelDataType expectedCalciteType =
121-
TypeConverter.DEFAULT.toCalcite(SubstraitTypeSystem.TYPE_FACTORY, expectedType);
132+
// Split args by comma (assuming simple types for this test suite)
133+
final String argsPart = matcher.group(1);
134+
final List<String> actualArgTypes =
135+
Arrays.stream(argsPart.split(",")).map(String::trim).toList();
122136

123-
// 3. Validate consistency: Ensure YAML derived type matches the TestSpec expectation string
124-
// This utilizes the previously unused 'expectedTypeName'
125137
assertEquals(
126-
expectedTypeName,
127-
expectedCalciteType.getSqlTypeName().toString(),
128-
() ->
129-
"YAML definition derived type does not match TestSpec expectation for "
130-
+ funcDef.name());
138+
expectedArgTypes.size(),
139+
actualArgTypes.size(),
140+
() -> "Signature argument count mismatch. Signature: " + signature);
141+
142+
// Positional Check
143+
for (int i = 0; i < expectedArgTypes.size(); i++) {
144+
final String expected = expectedArgTypes.get(i);
145+
final String actual = actualArgTypes.get(i);
146+
147+
final SqlTypeName sqlTypeName = SqlTypeName.valueOf(expected);
148+
final String familyName = sqlTypeName.getFamily().toString();
149+
150+
// Check if the actual slot matches the specific type OR the generic family
151+
// e.g. Expected "INTEGER" matches actual "<NUMERIC>" or "INTEGER"
152+
final boolean match = actual.contains(expected) || actual.contains(familyName);
153+
154+
final int index = i;
155+
assertTrue(
156+
match,
157+
() ->
158+
"Argument mismatch at index "
159+
+ index
160+
+ ".\n"
161+
+ "Expected: "
162+
+ expected
163+
+ " (Family: "
164+
+ familyName
165+
+ ")\n"
166+
+ "Actual: "
167+
+ actual
168+
+ "\n"
169+
+ "Full Signature: "
170+
+ signature);
171+
}
172+
}
173+
174+
private void verifyReturnTypeConsistency(
175+
final SqlOperator operator, final SimpleExtension.Function funcDef) {
176+
assertNotNull(operator.getReturnTypeInference(), "Return type inference is null");
177+
178+
// A. Expected: Evaluate YAML return type -> Convert to Calcite
179+
final Type yamlReturnType =
180+
TypeExpressionEvaluator.evaluateExpression(
181+
funcDef.returnType(), funcDef.args(), Collections.emptyList());
182+
final RelDataType expectedType = TypeConverter.DEFAULT.toCalcite(TYPE_FACTORY, yamlReturnType);
131183

132-
// 4. Infer actual type from the Calcite Operator using a minimal binding
133-
final RelDataType actualReturnType =
134-
operator.getReturnTypeInference().inferReturnType(createMockBinding(operator));
184+
// B. Actual: Infer from Operator (using empty binding, sufficient for static types)
185+
final RelDataType actualType =
186+
operator
187+
.getReturnTypeInference()
188+
.inferReturnType(createMockBinding(operator, Collections.emptyList()));
135189

136-
// 5. Compare Derived Expectation vs Actual Operator Inference
190+
// C. Compare
137191
assertEquals(
138-
expectedCalciteType.getSqlTypeName(),
139-
actualReturnType.getSqlTypeName(),
192+
expectedType.getSqlTypeName(),
193+
actualType.getSqlTypeName(),
140194
() -> "Return type mismatch for " + funcDef.name());
141195
assertEquals(
142-
expectedCalciteType.isNullable(),
143-
actualReturnType.isNullable(),
196+
expectedType.isNullable(),
197+
actualType.isNullable(),
144198
() -> "Nullability mismatch for " + funcDef.name());
145199
}
146200

147-
private SqlOperator findOperator(final String name) {
148-
return OPERATORS.stream()
149-
.filter(o -> o.getName().equalsIgnoreCase(name))
150-
.findFirst()
151-
.orElseThrow(() -> new AssertionError("Operator not found: " + name));
201+
private static SqlOperator getOperator(final String name) {
202+
final SqlOperator op = OPERATORS.get(name.toLowerCase());
203+
assertNotNull(op, "Operator not found: " + name);
204+
return op;
152205
}
153206

154-
private SimpleExtension.Function findFunctionDef(final String name) {
155-
return EXTENSIONS.scalarFunctions().stream()
156-
.filter(f -> f.name().equalsIgnoreCase(name))
157-
.findFirst()
158-
.orElseThrow(() -> new AssertionError("YAML Definition not found: " + name));
207+
private static SimpleExtension.Function getFunctionDef(final String name) {
208+
final SimpleExtension.Function func = FUNCTION_DEFS.get(name.toLowerCase());
209+
assertNotNull(func, "YAML Def not found: " + name);
210+
return func;
159211
}
160212

161-
/** Minimal anonymous implementation of SqlOperatorBinding to support return type inference. */
162-
private SqlOperatorBinding createMockBinding(final SqlOperator operator) {
163-
final RelDataTypeFactory typeFactory = SubstraitTypeSystem.TYPE_FACTORY;
164-
return new SqlOperatorBinding(typeFactory, operator) {
213+
private SqlOperatorBinding createMockBinding(
214+
final SqlOperator operator, final List<RelDataType> argumentTypes) {
215+
return new SqlOperatorBinding(TYPE_FACTORY, operator) {
165216
@Override
166217
public int getOperandCount() {
167-
return 0;
218+
return argumentTypes.size();
168219
}
169220

170221
@Override
171222
public RelDataType getOperandType(final int ordinal) {
172-
throw new IndexOutOfBoundsException();
223+
return argumentTypes.get(ordinal);
173224
}
174225

175226
@Override

0 commit comments

Comments
 (0)