77import io .substrait .extension .SimpleExtension ;
88import io .substrait .type .Type ;
99import io .substrait .type .TypeExpressionEvaluator ;
10+ import java .util .Arrays ;
1011import java .util .Collections ;
1112import java .util .List ;
12- import java .util .function .Consumer ;
13+ import java .util .Map ;
14+ import java .util .function .Function ;
15+ import java .util .regex .Matcher ;
16+ import java .util .regex .Pattern ;
17+ import java .util .stream .Collectors ;
1318import java .util .stream .Stream ;
1419import org .apache .calcite .rel .type .RelDataType ;
1520import org .apache .calcite .rel .type .RelDataTypeFactory ;
1621import org .apache .calcite .runtime .CalciteException ;
1722import org .apache .calcite .runtime .Resources ;
1823import org .apache .calcite .sql .SqlOperator ;
1924import org .apache .calcite .sql .SqlOperatorBinding ;
25+ import org .apache .calcite .sql .type .SqlTypeName ;
2026import org .apache .calcite .sql .validate .SqlValidatorException ;
2127import org .junit .jupiter .params .ParameterizedTest ;
2228import org .junit .jupiter .params .provider .MethodSource ;
2531class SimpleExtensionToSqlOperatorTest {
2632
2733 private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml" ;
34+ private static final RelDataTypeFactory TYPE_FACTORY = SubstraitTypeSystem .TYPE_FACTORY ;
35+
36+ private static final Map <String , SimpleExtension .Function > FUNCTION_DEFS ;
37+ private static final Map <String , SqlOperator > OPERATORS ;
38+
39+ static {
40+ final var extensions =
41+ SimpleExtension .load (
42+ CUSTOM_FUNCTION_PATH ,
43+ SimpleExtensionToSqlOperatorTest .class .getResourceAsStream (CUSTOM_FUNCTION_PATH ));
44+
45+ FUNCTION_DEFS =
46+ extensions .scalarFunctions ().stream ()
47+ .collect (
48+ Collectors .toUnmodifiableMap (f -> f .name ().toLowerCase (), Function .identity ()));
49+
50+ OPERATORS =
51+ SimpleExtensionToSqlOperator .from (extensions ).stream ()
52+ .collect (
53+ Collectors .toUnmodifiableMap (
54+ op -> op .getName ().toLowerCase (), Function .identity ()));
55+ }
2856
29- private static final SimpleExtension .ExtensionCollection EXTENSIONS =
30- SimpleExtension .load (
31- CUSTOM_FUNCTION_PATH ,
32- SimpleExtensionToSqlOperatorTest .class .getResourceAsStream (CUSTOM_FUNCTION_PATH ));
33-
34- private static final List <SqlOperator > OPERATORS = SimpleExtensionToSqlOperator .from (EXTENSIONS );
35-
36- /** Data carrier for test cases. */
57+ /** Test Specification. */
3758 record TestSpec (
3859 String name ,
3960 int minArgs ,
4061 int maxArgs ,
4162 SimpleExtension .Nullability nullability ,
42- String expectedReturnType ,
43- Consumer <SqlOperator > customValidator ) {
44-
45- TestSpec (
46- final String name ,
47- final int min ,
48- final int max ,
49- final SimpleExtension .Nullability nullability ,
50- final String returnType ) {
51- this (name , min , max , nullability , returnType , op -> {});
52- }
53- }
63+ List <String > expectedArgTypes ) {}
5464
5565 @ ParameterizedTest
5666 @ MethodSource ("provideTestSpecs" )
5767 void testCustomUdfConversion (final TestSpec spec ) {
58- final SqlOperator operator = findOperator (spec .name );
59- final SimpleExtension .Function funcDef = findFunctionDef (spec .name );
68+ final SqlOperator operator = getOperator (spec .name );
69+ final SimpleExtension .Function funcDef = getFunctionDef (spec .name );
6070
61- // 1. Verify Argument Counts
6271 assertEquals (
6372 spec .minArgs ,
6473 operator .getOperandCountRange ().getMin (),
@@ -67,109 +76,151 @@ void testCustomUdfConversion(final TestSpec spec) {
6776 spec .maxArgs ,
6877 operator .getOperandCountRange ().getMax (),
6978 () -> spec .name + ": Incorrect max args" );
70- assertNotNull (operator .getOperandTypeChecker (), () -> spec .name + ": Type checker missing" );
7179
72- // 2. Verify Nullability (if specified)
7380 if (spec .nullability != null ) {
7481 assertEquals (
7582 spec .nullability , funcDef .nullability (), () -> spec .name + ": Incorrect nullability" );
7683 }
7784
78- // 3. Verify Return Type
79- verifyReturnType (operator , funcDef , spec .expectedReturnType );
85+ if (!spec .expectedArgTypes .isEmpty ()) {
86+ verifyAllowedSignatures (operator , spec .expectedArgTypes );
87+ }
8088
81- // 4. Custom Validation
82- spec .customValidator .accept (operator );
89+ verifyReturnTypeConsistency (operator , funcDef );
8390 }
8491
8592 private static Stream <TestSpec > provideTestSpecs () {
8693 return Stream .of (
94+ new TestSpec ("REGEXP_EXTRACT_CUSTOM" , 2 , 2 , null , List .of ("VARCHAR" , "VARCHAR" )),
95+ new TestSpec (
96+ "FORMAT_TEXT" , 2 , 2 , SimpleExtension .Nullability .MIRROR , List .of ("VARCHAR" , "VARCHAR" )),
97+ new TestSpec (
98+ "SYSTEM_PROPERTY_GET" ,
99+ 1 ,
100+ 1 ,
101+ SimpleExtension .Nullability .DECLARED_OUTPUT ,
102+ List .of ("VARCHAR" )),
87103 new TestSpec (
88- "REGEXP_EXTRACT_CUSTOM " ,
104+ "SAFE_DIVIDE_CUSTOM " ,
89105 2 ,
90106 2 ,
91- null ,
92- "VARCHAR" ,
93- op -> {
94- final String sigs =
95- op .getOperandTypeChecker ().getAllowedSignatures (op , op .getName ()).toLowerCase ();
96- // Calcite represents string families as <character>
97- assertTrue (
98- sigs .contains ("varchar" ) || sigs .contains ("string" ) || sigs .contains ("character" ),
99- () -> "Signatures should contain string types. Actual: " + sigs );
100- }),
101- new TestSpec ("FORMAT_TEXT" , 2 , 2 , SimpleExtension .Nullability .MIRROR , "VARCHAR" ),
102- new TestSpec (
103- "SYSTEM_PROPERTY_GET" , 1 , 1 , SimpleExtension .Nullability .DECLARED_OUTPUT , "VARCHAR" ),
104- new TestSpec ("SAFE_DIVIDE_CUSTOM" , 2 , 2 , SimpleExtension .Nullability .DISCRETE , "REAL" ));
107+ SimpleExtension .Nullability .DISCRETE ,
108+ List .of ("INTEGER" , "INTEGER" )));
105109 }
106110
107- private void verifyReturnType (
108- final SqlOperator operator ,
109- final SimpleExtension .Function funcDef ,
110- final String expectedTypeName ) {
111- assertNotNull (funcDef .returnType (), "Return type missing in YAML" );
112- assertNotNull (operator .getReturnTypeInference (), "SQL Operator missing return type inference" );
111+ /**
112+ * Parses the operator's signature string and checks that the types match the expected list
113+ * index-by-index.
114+ */
115+ private void verifyAllowedSignatures (
116+ final SqlOperator operator , final List <String > expectedArgTypes ) {
117+ assertNotNull (operator .getOperandTypeChecker (), "Operand type checker is null" );
113118
114- // 1. Evaluate expected type from YAML
115- final Type expectedType =
116- TypeExpressionEvaluator .evaluateExpression (
117- funcDef .returnType (), funcDef .args (), Collections .emptyList ());
119+ // e.g., "SAFE_DIVIDE_CUSTOM(<NUMERIC>, <NUMERIC>)"
120+ final String signature =
121+ operator
122+ .getOperandTypeChecker ()
123+ .getAllowedSignatures (operator , operator .getName ())
124+ .toUpperCase ();
125+
126+ // Regex to capture arguments inside parentheses: NAME(ARG1, ARG2)
127+ final Pattern pattern = Pattern .compile (".*?\\ ((.*)\\ ).*" );
128+ final Matcher matcher = pattern .matcher (signature );
129+
130+ assertTrue (matcher .matches (), () -> "Signature format not recognized: " + signature );
118131
119- // 2. Convert expected Substrait type to Calcite type
120- final RelDataType expectedCalciteType =
121- TypeConverter .DEFAULT .toCalcite (SubstraitTypeSystem .TYPE_FACTORY , expectedType );
132+ // Split args by comma (assuming simple types for this test suite)
133+ final String argsPart = matcher .group (1 );
134+ final List <String > actualArgTypes =
135+ Arrays .stream (argsPart .split ("," )).map (String ::trim ).toList ();
122136
123- // 3. Validate consistency: Ensure YAML derived type matches the TestSpec expectation string
124- // This utilizes the previously unused 'expectedTypeName'
125137 assertEquals (
126- expectedTypeName ,
127- expectedCalciteType .getSqlTypeName ().toString (),
128- () ->
129- "YAML definition derived type does not match TestSpec expectation for "
130- + funcDef .name ());
138+ expectedArgTypes .size (),
139+ actualArgTypes .size (),
140+ () -> "Signature argument count mismatch. Signature: " + signature );
141+
142+ // Positional Check
143+ for (int i = 0 ; i < expectedArgTypes .size (); i ++) {
144+ final String expected = expectedArgTypes .get (i );
145+ final String actual = actualArgTypes .get (i );
146+
147+ final SqlTypeName sqlTypeName = SqlTypeName .valueOf (expected );
148+ final String familyName = sqlTypeName .getFamily ().toString ();
149+
150+ // Check if the actual slot matches the specific type OR the generic family
151+ // e.g. Expected "INTEGER" matches actual "<NUMERIC>" or "INTEGER"
152+ final boolean match = actual .contains (expected ) || actual .contains (familyName );
153+
154+ final int index = i ;
155+ assertTrue (
156+ match ,
157+ () ->
158+ "Argument mismatch at index "
159+ + index
160+ + ".\n "
161+ + "Expected: "
162+ + expected
163+ + " (Family: "
164+ + familyName
165+ + ")\n "
166+ + "Actual: "
167+ + actual
168+ + "\n "
169+ + "Full Signature: "
170+ + signature );
171+ }
172+ }
173+
174+ private void verifyReturnTypeConsistency (
175+ final SqlOperator operator , final SimpleExtension .Function funcDef ) {
176+ assertNotNull (operator .getReturnTypeInference (), "Return type inference is null" );
177+
178+ // A. Expected: Evaluate YAML return type -> Convert to Calcite
179+ final Type yamlReturnType =
180+ TypeExpressionEvaluator .evaluateExpression (
181+ funcDef .returnType (), funcDef .args (), Collections .emptyList ());
182+ final RelDataType expectedType = TypeConverter .DEFAULT .toCalcite (TYPE_FACTORY , yamlReturnType );
131183
132- // 4. Infer actual type from the Calcite Operator using a minimal binding
133- final RelDataType actualReturnType =
134- operator .getReturnTypeInference ().inferReturnType (createMockBinding (operator ));
184+ // B. Actual: Infer from Operator (using empty binding, sufficient for static types)
185+ final RelDataType actualType =
186+ operator
187+ .getReturnTypeInference ()
188+ .inferReturnType (createMockBinding (operator , Collections .emptyList ()));
135189
136- // 5 . Compare Derived Expectation vs Actual Operator Inference
190+ // C . Compare
137191 assertEquals (
138- expectedCalciteType .getSqlTypeName (),
139- actualReturnType .getSqlTypeName (),
192+ expectedType .getSqlTypeName (),
193+ actualType .getSqlTypeName (),
140194 () -> "Return type mismatch for " + funcDef .name ());
141195 assertEquals (
142- expectedCalciteType .isNullable (),
143- actualReturnType .isNullable (),
196+ expectedType .isNullable (),
197+ actualType .isNullable (),
144198 () -> "Nullability mismatch for " + funcDef .name ());
145199 }
146200
147- private SqlOperator findOperator (final String name ) {
148- return OPERATORS .stream ()
149- .filter (o -> o .getName ().equalsIgnoreCase (name ))
150- .findFirst ()
151- .orElseThrow (() -> new AssertionError ("Operator not found: " + name ));
201+ private static SqlOperator getOperator (final String name ) {
202+ final SqlOperator op = OPERATORS .get (name .toLowerCase ());
203+ assertNotNull (op , "Operator not found: " + name );
204+ return op ;
152205 }
153206
154- private SimpleExtension .Function findFunctionDef (final String name ) {
155- return EXTENSIONS .scalarFunctions ().stream ()
156- .filter (f -> f .name ().equalsIgnoreCase (name ))
157- .findFirst ()
158- .orElseThrow (() -> new AssertionError ("YAML Definition not found: " + name ));
207+ private static SimpleExtension .Function getFunctionDef (final String name ) {
208+ final SimpleExtension .Function func = FUNCTION_DEFS .get (name .toLowerCase ());
209+ assertNotNull (func , "YAML Def not found: " + name );
210+ return func ;
159211 }
160212
161- /** Minimal anonymous implementation of SqlOperatorBinding to support return type inference. */
162- private SqlOperatorBinding createMockBinding (final SqlOperator operator ) {
163- final RelDataTypeFactory typeFactory = SubstraitTypeSystem .TYPE_FACTORY ;
164- return new SqlOperatorBinding (typeFactory , operator ) {
213+ private SqlOperatorBinding createMockBinding (
214+ final SqlOperator operator , final List <RelDataType > argumentTypes ) {
215+ return new SqlOperatorBinding (TYPE_FACTORY , operator ) {
165216 @ Override
166217 public int getOperandCount () {
167- return 0 ;
218+ return argumentTypes . size () ;
168219 }
169220
170221 @ Override
171222 public RelDataType getOperandType (final int ordinal ) {
172- throw new IndexOutOfBoundsException ( );
223+ return argumentTypes . get ( ordinal );
173224 }
174225
175226 @ Override
0 commit comments