diff --git a/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java new file mode 100644 index 000000000..377020bb3 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java @@ -0,0 +1,49 @@ +package io.substrait.isthmus; + +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.expression.FunctionMappings; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; + +public class ExtensionUtils { + + /** + * Extracts dynamic extensions from a collection of extensions. + * + *

A dynamic extension is a user-defined function (UDF) that is not part of the standard + * Substrait function catalog. These are custom functions that users define and provide at + * runtime, extending the built-in function set with domain-specific or application-specific + * operations. + * + *

This method filters out all functions that are already known to the Calcite operator table + * (the standard/built-in functions) and returns only the custom functions that represent new + * capabilities not available in the default function set. + * + *

Example: If a user defines a custom UDF "my_hash_function" that computes a + * proprietary hash, this would be a dynamic extension since it's not part of the standard + * Substrait specification. + * + * @param extensions the complete collection of extensions (both standard and custom) + * @return a new ExtensionCollection containing only the dynamic (custom/user-defined) functions + * that are not present in the standard Substrait function catalog + */ + public static SimpleExtension.ExtensionCollection getDynamicExtensions( + SimpleExtension.ExtensionCollection extensions) { + Set knownFunctionNames = + FunctionMappings.SCALAR_SIGS.stream() + .map(FunctionMappings.Sig::name) + .collect(Collectors.toSet()); + + List customFunctions = + extensions.scalarFunctions().stream() + .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase(Locale.ROOT))) + .collect(Collectors.toList()); + + return SimpleExtension.ExtensionCollection.builder() + .scalarFunctions(customFunctions) + // TODO: handle aggregates and other functions + .build(); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java b/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java index 8db29f73c..a54f24146 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java +++ b/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java @@ -17,4 +17,19 @@ public abstract class FeatureBoard { public Casing unquotedCasing() { return Casing.TO_UPPER; } + + /** + * Controls whether to support dynamic user-defined functions (UDFs) during SQL to Substrait plan + * conversion. + * + *

When enabled, custom functions defined in extension YAML files are available for use in SQL + * queries. These functions will be dynamically converted to SQL operators during plan conversion. + * This feature must be explicitly enabled by users and is disabled by default. + * + * @return true if dynamic UDFs should be supported; false otherwise (default) + */ + @Value.Default + public boolean allowDynamicUdfs() { + return false; + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java b/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java new file mode 100644 index 000000000..3c61acd94 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java @@ -0,0 +1,342 @@ +package io.substrait.isthmus; + +import io.substrait.extension.SimpleExtension; +import io.substrait.function.ParameterizedType; +import io.substrait.function.ParameterizedTypeVisitor; +import io.substrait.function.TypeExpression; +import io.substrait.type.Type; +import io.substrait.type.TypeExpressionEvaluator; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; + +public final class SimpleExtensionToSqlOperator { + + private static final RelDataTypeFactory DEFAULT_TYPE_FACTORY = + new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM); + + private 
static final CalciteTypeVisitor CALCITE_TYPE_VISITOR = new CalciteTypeVisitor(); + + private SimpleExtensionToSqlOperator() {} + + public static List from(SimpleExtension.ExtensionCollection collection) { + return from(collection, DEFAULT_TYPE_FACTORY); + } + + public static List from( + SimpleExtension.ExtensionCollection collection, RelDataTypeFactory typeFactory) { + return from(collection, typeFactory, TypeConverter.DEFAULT); + } + + public static List from( + SimpleExtension.ExtensionCollection collection, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter) { + // TODO: add support for windows functions + return Stream.concat( + collection.scalarFunctions().stream(), collection.aggregateFunctions().stream()) + .map(function -> toSqlFunction(function, typeFactory, typeConverter)) + .collect(Collectors.toList()); + } + + private static SqlFunction toSqlFunction( + SimpleExtension.Function function, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter) { + + List argFamilies = new ArrayList<>(); + + for (SimpleExtension.Argument arg : function.requiredArguments()) { + if (arg instanceof SimpleExtension.ValueArgument) { + SimpleExtension.ValueArgument valueArg = (SimpleExtension.ValueArgument) arg; + SqlTypeName typeName = valueArg.value().accept(CALCITE_TYPE_VISITOR); + argFamilies.add(typeName.getFamily()); + } else if (arg instanceof SimpleExtension.EnumArgument) { + // Treat an EnumArgument as a required string literal. 
+ argFamilies.add(SqlTypeFamily.STRING); + } + } + + SqlReturnTypeInference returnTypeInference = + new SubstraitReturnTypeInference(function, typeFactory, typeConverter); + + return new SqlFunction( + function.name(), + SqlKind.OTHER_FUNCTION, + returnTypeInference, + null, + OperandTypes.family(argFamilies), + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } + + private static class SubstraitReturnTypeInference implements SqlReturnTypeInference { + + private final SimpleExtension.Function function; + private final RelDataTypeFactory typeFactory; + private final TypeConverter typeConverter; + + private SubstraitReturnTypeInference( + SimpleExtension.Function function, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter) { + this.function = function; + this.typeFactory = typeFactory; + this.typeConverter = typeConverter; + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + List substraitArgTypes = + opBinding.collectOperandTypes().stream() + .map(typeConverter::toSubstrait) + .collect(Collectors.toList()); + + TypeExpression returnExpression = function.returnType(); + Type resolvedSubstraitType = + TypeExpressionEvaluator.evaluateExpression( + returnExpression, function.args(), substraitArgTypes); + + boolean finalIsNullable; + switch (function.nullability()) { + case MIRROR: + // If any input is nullable, the output is nullable. + finalIsNullable = + opBinding.collectOperandTypes().stream().anyMatch(RelDataType::isNullable); + break; + case DISCRETE: + case DECLARED_OUTPUT: + default: + // Use the nullability declared on the resolved Substrait type. 
+ finalIsNullable = resolvedSubstraitType.nullable(); + break; + } + + RelDataType baseCalciteType = typeConverter.toCalcite(typeFactory, resolvedSubstraitType); + + return typeFactory.createTypeWithNullability(baseCalciteType, finalIsNullable); + } + } + + private static class CalciteTypeVisitor + extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor< + SqlTypeName, RuntimeException> { + + private CalciteTypeVisitor() { + super("Type not supported for Calcite conversion."); + } + + @Override + public SqlTypeName visit(Type.Bool expr) { + return SqlTypeName.BOOLEAN; + } + + @Override + public SqlTypeName visit(Type.I8 expr) { + return SqlTypeName.TINYINT; + } + + @Override + public SqlTypeName visit(Type.I16 expr) { + return SqlTypeName.SMALLINT; + } + + @Override + public SqlTypeName visit(Type.I32 expr) { + return SqlTypeName.INTEGER; + } + + @Override + public SqlTypeName visit(Type.I64 expr) { + return SqlTypeName.BIGINT; + } + + @Override + public SqlTypeName visit(Type.FP32 expr) { + return SqlTypeName.FLOAT; + } + + @Override + public SqlTypeName visit(Type.FP64 expr) { + return SqlTypeName.DOUBLE; + } + + @Override + public SqlTypeName visit(Type.Str expr) { + return SqlTypeName.VARCHAR; + } + + @Override + public SqlTypeName visit(Type.Binary expr) { + return SqlTypeName.VARBINARY; + } + + @Override + public SqlTypeName visit(Type.Date expr) { + return SqlTypeName.DATE; + } + + @Override + public SqlTypeName visit(Type.Time expr) { + return SqlTypeName.TIME; + } + + @Override + public SqlTypeName visit(Type.TimestampTZ expr) { + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + } + + @Override + public SqlTypeName visit(Type.Timestamp expr) { + return SqlTypeName.TIMESTAMP; + } + + @Override + public SqlTypeName visit(Type.IntervalYear year) { + return SqlTypeName.INTERVAL_YEAR_MONTH; + } + + @Override + public SqlTypeName visit(Type.IntervalDay day) { + return SqlTypeName.INTERVAL_DAY; + } + + @Override + public SqlTypeName visit(Type.UUID 
expr) { + return SqlTypeName.UUID; + } + + @Override + public SqlTypeName visit(Type.Struct struct) { + return SqlTypeName.ROW; + } + + @Override + public SqlTypeName visit(Type.ListType listType) { + return SqlTypeName.ARRAY; + } + + @Override + public SqlTypeName visit(Type.Map map) { + return SqlTypeName.MAP; + } + + @Override + public SqlTypeName visit(ParameterizedType.FixedChar expr) { + return SqlTypeName.CHAR; + } + + @Override + public SqlTypeName visit(ParameterizedType.VarChar expr) { + return SqlTypeName.VARCHAR; + } + + @Override + public SqlTypeName visit(ParameterizedType.FixedBinary expr) { + return SqlTypeName.BINARY; + } + + @Override + public SqlTypeName visit(ParameterizedType.Decimal expr) { + return SqlTypeName.DECIMAL; + } + + @Override + public SqlTypeName visit(ParameterizedType.Struct expr) { + return SqlTypeName.ROW; + } + + @Override + public SqlTypeName visit(ParameterizedType.ListType expr) { + return SqlTypeName.ARRAY; + } + + @Override + public SqlTypeName visit(ParameterizedType.Map expr) { + return SqlTypeName.MAP; + } + + @Override + public SqlTypeName visit(ParameterizedType.PrecisionTimestamp expr) { + return SqlTypeName.TIMESTAMP; + } + + @Override + public SqlTypeName visit(ParameterizedType.PrecisionTimestampTZ expr) { + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + } + + @Override + public SqlTypeName visit(ParameterizedType.PrecisionTime expr) { + return SqlTypeName.TIME; + } + + @Override + public SqlTypeName visit(ParameterizedType.IntervalDay expr) { + return SqlTypeName.INTERVAL_DAY; + } + + @Override + public SqlTypeName visit(ParameterizedType.StringLiteral expr) { + String type = expr.value().toUpperCase(java.util.Locale.ROOT); + + if (type.startsWith("ANY")) { + return SqlTypeName.ANY; + } + + switch (type) { + case "BOOLEAN": + return SqlTypeName.BOOLEAN; + case "I8": + return SqlTypeName.TINYINT; + case "I16": + return SqlTypeName.SMALLINT; + case "I32": + return SqlTypeName.INTEGER; + case "I64": + return SqlTypeName.BIGINT; +
case "FP32": + return SqlTypeName.FLOAT; + case "FP64": + return SqlTypeName.DOUBLE; + case "STRING": + return SqlTypeName.VARCHAR; + case "BINARY": + return SqlTypeName.VARBINARY; + case "TIMESTAMP": + return SqlTypeName.TIMESTAMP; + case "TIMESTAMP_TZ": + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + case "DATE": + return SqlTypeName.DATE; + case "TIME": + return SqlTypeName.TIME; + case "UUID": + return SqlTypeName.UUID; + default: + if (type.startsWith("DECIMAL")) { + return SqlTypeName.DECIMAL; + } + if (type.startsWith("STRUCT")) { + return SqlTypeName.ROW; + } + if (type.startsWith("LIST")) { + return SqlTypeName.ARRAY; + } + return super.visit(expr); + } + } + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 671deabe5..f667deab0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -19,8 +19,7 @@ import org.apache.calcite.sql2rel.SqlToRelConverter; public class SqlConverterBase { - protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - DefaultExtensionCatalog.DEFAULT_COLLECTION; + protected final SimpleExtension.ExtensionCollection extensionCollection; public static final CalciteConnectionConfig CONNECTION_CONFIG = CalciteConnectionConfig.DEFAULT.set( @@ -36,7 +35,8 @@ public class SqlConverterBase { protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); final FeatureBoard featureBoard; - protected SqlConverterBase(FeatureBoard features) { + protected SqlConverterBase( + FeatureBoard features, SimpleExtension.ExtensionCollection extensionCollection) { this.factory = SubstraitTypeSystem.TYPE_FACTORY; this.config = CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); @@ -55,5 +55,11 @@ protected SqlConverterBase(FeatureBoard features) { 
.withUnquotedCasing(featureBoard.unquotedCasing()) .withParserFactory(SqlDdlParserImpl.FACTORY) .withConformance(SqlConformanceEnum.LENIENT); + + this.extensionCollection = extensionCollection; + } + + protected SqlConverterBase(FeatureBoard features) { + this(features, DefaultExtensionCatalog.DEFAULT_COLLECTION); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index c32fab07c..3d45f8bde 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -3,6 +3,7 @@ import io.substrait.extendedexpression.ExtendedExpression; import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; import io.substrait.extendedexpression.ImmutableExtendedExpression.Builder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.calcite.SubstraitTable; import io.substrait.isthmus.expression.RexExpressionConverter; @@ -34,12 +35,12 @@ public class SqlExpressionToSubstrait extends SqlConverterBase { protected final RexExpressionConverter rexConverter; public SqlExpressionToSubstrait() { - this(FEATURES_DEFAULT, EXTENSION_COLLECTION); + this(FEATURES_DEFAULT, DefaultExtensionCatalog.DEFAULT_COLLECTION); } public SqlExpressionToSubstrait( FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { - super(features); + super(features, extensions); ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter(extensions.scalarFunctions(), factory); this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 3e19ca58c..abe935a75 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java 
+++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,22 +1,49 @@ package io.substrait.isthmus; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.ImmutablePlan.Builder; import io.substrait.plan.Plan; import io.substrait.plan.Plan.Version; import io.substrait.plan.PlanProtoConverter; +import java.util.List; import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.util.SqlOperatorTables; /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { + private final SqlOperatorTable operatorTable; public SqlToSubstrait() { - this(null); + this(DefaultExtensionCatalog.DEFAULT_COLLECTION, null); } public SqlToSubstrait(FeatureBoard features) { - super(features); + this(DefaultExtensionCatalog.DEFAULT_COLLECTION, features); + } + + public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + super(features, extensions); + + if (featureBoard.allowDynamicUdfs()) { + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + if (!dynamicExtensionCollection.scalarFunctions().isEmpty() + || !dynamicExtensionCollection.aggregateFunctions().isEmpty()) { + List generatedDynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, this.factory); + this.operatorTable = + SqlOperatorTables.chain( + SubstraitOperatorTable.INSTANCE, SqlOperatorTables.of(generatedDynamicOperators)); + return; + } + } + this.operatorTable = SubstraitOperatorTable.INSTANCE; } /** @@ -53,8 +80,8 @@ public Plan convert(String 
sqlStatements, Prepare.CatalogReader catalogReader) builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build()); // TODO: consider case in which one sql passes conversion while others don't - SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader).stream() - .map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard)) + SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader, operatorTable).stream() + .map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard)) .forEach(root -> builder.addRoots(root)); return builder.build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 96e091578..47daf97e2 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -1,7 +1,5 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; - import com.google.common.collect.ImmutableList; import com.google.common.collect.Range; import com.google.common.collect.RangeMap; @@ -14,6 +12,7 @@ import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; +import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.relation.AbstractDdlRel; @@ -110,10 +109,18 @@ public SubstraitRelNodeConverter( SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory, RelBuilder relBuilder) { + this(extensions, typeFactory, relBuilder, ImmutableFeatureBoard.builder().build()); + } + + public SubstraitRelNodeConverter( + SimpleExtension.ExtensionCollection 
extensions, + RelDataTypeFactory typeFactory, + RelBuilder relBuilder, + FeatureBoard featureBoard) { this( typeFactory, relBuilder, - new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory), + createScalarFunctionConverter(extensions, typeFactory, featureBoard.allowDynamicUdfs()), new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory), new WindowFunctionConverter(extensions.windowFunctions(), typeFactory), TypeConverter.DEFAULT); @@ -155,11 +162,68 @@ public SubstraitRelNodeConverter( this.expressionRexConverter.setRelNodeConverter(this); } + private static ScalarFunctionConverter createScalarFunctionConverter( + SimpleExtension.ExtensionCollection extensions, + RelDataTypeFactory typeFactory, + boolean allowDynamicUdfs) { + + List additionalSignatures; + + if (allowDynamicUdfs) { + java.util.Set knownFunctionNames = + FunctionMappings.SCALAR_SIGS.stream() + .map(FunctionMappings.Sig::name) + .collect(Collectors.toSet()); + + List dynamicFunctions = + extensions.scalarFunctions().stream() + .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase(java.util.Locale.ROOT))) + .collect(Collectors.toList()); + + if (dynamicFunctions.isEmpty()) { + additionalSignatures = Collections.emptyList(); + } else { + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + SimpleExtension.ExtensionCollection.builder().scalarFunctions(dynamicFunctions).build(); + + List dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + + additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); + } + } else { + additionalSignatures = Collections.emptyList(); + } + + return new ScalarFunctionConverter( + extensions.scalarFunctions(), additionalSignatures, typeFactory, TypeConverter.DEFAULT); + } + public static RelNode convert( Rel relRoot, RelOptCluster relOptCluster, Prepare.CatalogReader catalogReader, - SqlParser.Config parserConfig) { +
SqlParser.Config parserConfig, + SimpleExtension.ExtensionCollection extensions) { + return convert( + relRoot, + relOptCluster, + catalogReader, + parserConfig, + extensions, + ImmutableFeatureBoard.builder().build()); + } + + public static RelNode convert( + Rel relRoot, + RelOptCluster relOptCluster, + Prepare.CatalogReader catalogReader, + SqlParser.Config parserConfig, + SimpleExtension.ExtensionCollection extensions, + FeatureBoard featureBoard) { RelBuilder relBuilder = RelBuilder.create( Frameworks.newConfigBuilder() @@ -171,7 +235,7 @@ public static RelNode convert( return relRoot.accept( new SubstraitRelNodeConverter( - EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), relBuilder), + extensions, relOptCluster.getTypeFactory(), relBuilder, featureBoard), Context.newContext()); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 3dde72c13..5c3bb7a72 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -9,6 +9,7 @@ import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.CallConverters; +import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.LiteralConverter; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; @@ -62,6 +63,7 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; @@ -88,10 +90,31 @@ public SubstraitRelVisitor( RelDataTypeFactory typeFactory, 
SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + this.typeConverter = TypeConverter.DEFAULT; - ArrayList converters = new ArrayList(); + ArrayList converters = new ArrayList<>(); converters.addAll(CallConverters.defaults(typeConverter)); - converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)); + + if (features.allowDynamicUdfs()) { + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + List dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + + List additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); + converters.add( + new ScalarFunctionConverter( + extensions.scalarFunctions(), + additionalSignatures, + typeFactory, + TypeConverter.DEFAULT)); + } else { + converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)); + } + converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); this.aggregateFunctionConverter = new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index 8dcfbf9e0..772a3e192 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -38,6 +38,7 @@ public class SubstraitToCalcite { protected final RelDataTypeFactory typeFactory; protected final TypeConverter typeConverter; protected final Prepare.CatalogReader catalogReader; + protected final FeatureBoard featureBoard; public SubstraitToCalcite( SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { @@ -63,10 +64,25 @@ public SubstraitToCalcite( RelDataTypeFactory typeFactory, TypeConverter typeConverter, 
Prepare.CatalogReader catalogReader) { + this( + extensions, + typeFactory, + typeConverter, + catalogReader, + ImmutableFeatureBoard.builder().build()); + } + + public SubstraitToCalcite( + SimpleExtension.ExtensionCollection extensions, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter, + Prepare.CatalogReader catalogReader, + FeatureBoard featureBoard) { this.extensions = extensions; this.typeFactory = typeFactory; this.typeConverter = typeConverter; this.catalogReader = catalogReader; + this.featureBoard = featureBoard; } /** @@ -94,7 +110,7 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) { *

Override this method to customize the {@link SubstraitRelNodeConverter}. */ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder); + return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder, featureBoard); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index 421b45317..e327ab007 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import io.substrait.extension.SimpleExtension; import io.substrait.relation.Rel; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; @@ -10,7 +11,12 @@ public SubstraitToSql() { super(FEATURES_DEFAULT); } + public SubstraitToSql(SimpleExtension.ExtensionCollection extensions) { + super(FEATURES_DEFAULT, extensions); + } + public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) { - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, catalog, parserConfig); + return SubstraitRelNodeConverter.convert( + relRoot, relOptCluster, catalog, parserConfig, extensionCollection); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index d738ef157..b5604d4d9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -25,7 +25,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.IdentityHashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -129,8 +128,7 @@ public FunctionConverter( .collect( 
Multimaps.toMultimap( FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create)); - IdentityHashMap matcherMap = - new IdentityHashMap(); + Map matcherMap = new HashMap<>(); for (String key : nameToFn.keySet()) { Collection sigs = calciteOperators.get(key); if (sigs.isEmpty()) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java index a87e29563..b46d30e7c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java @@ -15,6 +15,7 @@ import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.SqlToRelConverter; @@ -41,6 +42,23 @@ public static RelRoot convertQuery(String sqlStatement, Prepare.CatalogReader ca return convertQuery(sqlStatement, catalogReader, validator, createDefaultRelOptCluster()); } + /** + * Converts a SQL statement to a Calcite {@link RelRoot}. 
+ * + * @param sqlStatement a SQL statement string + * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in + * the SQL statement + * @param operatorTable the {@link SqlOperatorTable} for controlling valid operators + * @return a {@link RelRoot} corresponding to the given SQL statement + * @throws SqlParseException if there is an error while parsing the SQL statement + */ + public static RelRoot convertQuery( + String sqlStatement, Prepare.CatalogReader catalogReader, SqlOperatorTable operatorTable) + throws SqlParseException { + SqlValidator validator = new SubstraitSqlValidator(catalogReader, operatorTable); + return convertQuery(sqlStatement, catalogReader, validator, createDefaultRelOptCluster()); + } + /** * Converts a SQL statement to a Calcite {@link RelRoot}. * @@ -72,6 +90,24 @@ public static RelRoot convertQuery( return relRoots.get(0); } + /** + * Converts one or more SQL statements to a List of {@link RelRoot}, with one {@link RelRoot} per + * statement. + * + * @param sqlStatements a string containing one or more SQL statements + * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in + * the SQL statements + * @param operatorTable the {@link SqlOperatorTable} for controlling valid operators + * @return a list of {@link RelRoot}s corresponding to the given SQL statements + * @throws SqlParseException if there is an error while parsing the SQL statements + */ + public static List convertQueries( + String sqlStatements, Prepare.CatalogReader catalogReader, SqlOperatorTable operatorTable) + throws SqlParseException { + SqlValidator validator = new SubstraitSqlValidator(catalogReader, operatorTable); + return convertQueries(sqlStatements, catalogReader, validator, createDefaultRelOptCluster()); + } + /** * Converts one or more SQL statements to a List of {@link RelRoot}, with one {@link RelRoot} per * statement. 
diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java index 52be2d6a5..07b6edda8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java @@ -2,6 +2,7 @@ import io.substrait.isthmus.calcite.SubstraitOperatorTable; import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorImpl; @@ -12,4 +13,8 @@ public class SubstraitSqlValidator extends SqlValidatorImpl { public SubstraitSqlValidator(Prepare.CatalogReader catalogReader) { super(SubstraitOperatorTable.INSTANCE, catalogReader, catalogReader.getTypeFactory(), CONFIG); } + + public SubstraitSqlValidator(Prepare.CatalogReader catalogReader, SqlOperatorTable opTable) { + super(opTable, catalogReader, catalogReader.getTypeFactory(), CONFIG); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java index e21d3b653..94f77c6a9 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static io.substrait.isthmus.SubstraitTypeSystem.YEAR_MONTH_INTERVAL; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -11,6 +10,8 @@ import io.substrait.expression.Expression.Literal; import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelNodeConverter.Context; import 
io.substrait.isthmus.expression.ExpressionRexConverter; import io.substrait.isthmus.expression.RexExpressionConverter; @@ -35,6 +36,8 @@ import org.junit.jupiter.api.Test; class CalciteLiteralTest extends CalciteObjs { + protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + DefaultExtensionCatalog.DEFAULT_COLLECTION; private final ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), type); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java index dc3cfdbbc..7a0d4ddc5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java @@ -1,12 +1,13 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ImmutableAggregateFunctionInvocation; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; import io.substrait.relation.Aggregate; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -17,6 +18,8 @@ import org.junit.jupiter.api.Test; class ComplexAggregateTest extends PlanTestBase { + protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + DefaultExtensionCatalog.DEFAULT_COLLECTION; final TypeCreator R = TypeCreator.of(false); SubstraitBuilder b = new SubstraitBuilder(extensions); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java index afe088d30..300786f21 100644 --- 
a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java @@ -1,10 +1,11 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; import io.substrait.relation.Rel; import io.substrait.type.TypeCreator; import java.io.PrintWriter; @@ -20,6 +21,9 @@ class ComplexSortTest extends PlanTestBase { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + DefaultExtensionCatalog.DEFAULT_COLLECTION; + final TypeCreator R = TypeCreator.of(false); SubstraitBuilder b = new SubstraitBuilder(extensions); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index ed63f0c47..a99c5ff35 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -1,8 +1,9 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.Plan; @@ -13,6 +14,9 @@ class NameRoundtripTest extends PlanTestBase { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + DefaultExtensionCatalog.DEFAULT_COLLECTION; + @Test void preserveNamesFromSql() throws Exception { String createStatement = "CREATE TABLE foo(a BIGINT, b BIGINT)"; diff --git 
a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index b408303a0..8f392aae2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -1,8 +1,9 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import java.io.IOException; import org.apache.calcite.plan.hep.HepPlanner; @@ -16,6 +17,9 @@ class OptimizerIntegrationTest extends PlanTestBase { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + DefaultExtensionCatalog.DEFAULT_COLLECTION; + @Test void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException { String query = diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 7cf5c38c5..a37916bc4 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -8,6 +8,7 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; @@ -38,12 +39,12 @@ import org.apache.calcite.tools.RelBuilder; public class PlanTestBase { - protected final SimpleExtension.ExtensionCollection extensions = - SqlConverterBase.EXTENSION_COLLECTION; + protected final SimpleExtension.ExtensionCollection 
extensions; + protected final RelCreator creator = new RelCreator(); protected final RelBuilder builder = creator.createRelBuilder(); protected final RelDataTypeFactory typeFactory = creator.typeFactory(); - protected final SubstraitBuilder substraitBuilder = new SubstraitBuilder(extensions); + protected final SubstraitBuilder substraitBuilder; protected static final TypeCreator R = TypeCreator.of(false); protected static final TypeCreator N = TypeCreator.of(true); @@ -63,6 +64,15 @@ public class PlanTestBase { protected static CalciteCatalogReader TPCDS_CATALOG = PlanTestBase.schemaToCatalog("tpcds", TPCDS_SCHEMA); + protected PlanTestBase() { + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + protected PlanTestBase(SimpleExtension.ExtensionCollection extensions) { + this.extensions = extensions; + this.substraitBuilder = new SubstraitBuilder(extensions); + } + public static String asString(String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); } @@ -142,6 +152,76 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( return relRoot2; } + /** + * Verifies that the given query can be converted through multiple round trips, with loose POJO + * comparison. + * + *

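<p>"Loose" here means not comparing the initial POJO (from SQL→Substrait conversion) to the + * first POJO after the round trip (from Substrait→Calcite→Substrait conversion), due to optimizer + * differences between: + * + * <ul>
+ * <li>SqlNode→RelRoot conversion during SQL→Substrait conversion + * <li>RelBuilder/RexBuilder optimization during Substrait→Calcite conversion
+ * </ul> + * + * <p>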

Instead, this method compares the second and third round-trip POJOs, ensuring that + * subsequent round trips produce stable results. + * + * @param query the SQL query to test + * @param catalogReader the Calcite catalog with table definitions + * @param featureBoard optional FeatureBoard to control conversion behavior (e.g., dynamic UDFs). + * If null, a default FeatureBoard is used. + */ + protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( + String query, Prepare.CatalogReader catalogReader, FeatureBoard featureBoard) + throws Exception { + // Use provided FeatureBoard, or create default if null + FeatureBoard features = + featureBoard != null ? featureBoard : ImmutableFeatureBoard.builder().build(); + + SubstraitToCalcite substraitToCalcite = + new SubstraitToCalcite(extensions, typeFactory, TypeConverter.DEFAULT, null, features); + SqlToSubstrait s = new SqlToSubstrait(extensions, features); + + // 1. SQL -> Substrait Plan + Plan plan1 = s.convert(query, catalogReader); + + // 2. Substrait Plan -> Substrait Root (POJO 1) + Plan.Root pojo1 = plan1.getRoots().get(0); + + // 3. Substrait Root -> Calcite RelNode + RelRoot relRoot2 = substraitToCalcite.convert(pojo1); + + // 4. Calcite RelNode -> Substrait Root (POJO 2) + Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions, features); + + // Note: pojo1 and pojo2 may differ due to different optimization strategies applied by: + // - SqlNode->RelRoot conversion during SQL->Substrait conversion + // - RelBuilder/RexBuilder optimization during Substrait->Calcite conversion + // This is expected, so we don't compare pojo1 and pojo2. + + // 5. Substrait Root 2 -> Calcite RelNode + RelRoot relRoot3 = substraitToCalcite.convert(pojo2); + + // 6. 
Calcite RelNode -> Substrait Root (POJO 3) + Plan.Root pojo3 = SubstraitRelVisitor.convert(relRoot3, extensions, features); + + // Verify that subsequent round trips are stable (pojo2 and pojo3 should be identical) + assertEquals(pojo2, pojo3); + return relRoot2; + } + + /** + * Convenience overload of {@link #assertSqlSubstraitRelRoundTripLoosePojoComparison(String, + * Prepare.CatalogReader, FeatureBoard)} with default FeatureBoard behavior (no dynamic UDFs). + */ + protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( + String query, Prepare.CatalogReader catalogReader) throws Exception { + return assertSqlSubstraitRelRoundTripLoosePojoComparison( + query, catalogReader, ImmutableFeatureBoard.builder().build()); + } + @Beta protected void assertFullRoundTrip(String query) throws SqlParseException { assertFullRoundTrip(query, TPCH_CATALOG); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java new file mode 100644 index 000000000..030e501d8 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java @@ -0,0 +1,232 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.extension.SimpleExtension; +import io.substrait.type.Type; +import io.substrait.type.TypeExpressionEvaluator; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.runtime.CalciteException; 
+import org.apache.calcite.runtime.Resources; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidatorException; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +/** Tests for conversion of SimpleExtension function definitions to Calcite SqlOperators. */ +class SimpleExtensionToSqlOperatorTest { + + private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml"; + private static final RelDataTypeFactory TYPE_FACTORY = SubstraitTypeSystem.TYPE_FACTORY; + + private static final Map FUNCTION_DEFS; + private static final Map OPERATORS; + + static { + final SimpleExtension.ExtensionCollection extensions = + SimpleExtension.load( + CUSTOM_FUNCTION_PATH, + SimpleExtensionToSqlOperatorTest.class.getResourceAsStream(CUSTOM_FUNCTION_PATH)); + + FUNCTION_DEFS = + extensions.scalarFunctions().stream() + .collect( + Collectors.toUnmodifiableMap(f -> f.name().toLowerCase(), Function.identity())); + + OPERATORS = + SimpleExtensionToSqlOperator.from(extensions).stream() + .collect( + Collectors.toUnmodifiableMap( + op -> op.getName().toLowerCase(), Function.identity())); + } + + /** Test Specification. 
*/ + record TestSpec( + String name, + int minArgs, + int maxArgs, + SimpleExtension.Nullability nullability, + List expectedArgTypes) {} + + @ParameterizedTest + @MethodSource("provideTestSpecs") + void testCustomUdfConversion(final TestSpec spec) { + final SqlOperator operator = getOperator(spec.name); + final SimpleExtension.Function funcDef = getFunctionDef(spec.name); + + assertEquals( + spec.minArgs, + operator.getOperandCountRange().getMin(), + () -> spec.name + ": Incorrect min args"); + assertEquals( + spec.maxArgs, + operator.getOperandCountRange().getMax(), + () -> spec.name + ": Incorrect max args"); + + if (spec.nullability != null) { + assertEquals( + spec.nullability, funcDef.nullability(), () -> spec.name + ": Incorrect nullability"); + } + + if (!spec.expectedArgTypes.isEmpty()) { + verifyAllowedSignatures(operator, spec.expectedArgTypes); + } + + verifyReturnTypeConsistency(operator, funcDef); + } + + private static Stream provideTestSpecs() { + return Stream.of( + new TestSpec("REGEXP_EXTRACT_CUSTOM", 2, 2, null, List.of("VARCHAR", "VARCHAR")), + new TestSpec( + "FORMAT_TEXT", 2, 2, SimpleExtension.Nullability.MIRROR, List.of("VARCHAR", "VARCHAR")), + new TestSpec( + "SYSTEM_PROPERTY_GET", + 1, + 1, + SimpleExtension.Nullability.DECLARED_OUTPUT, + List.of("VARCHAR")), + new TestSpec( + "SAFE_DIVIDE_CUSTOM", + 2, + 2, + SimpleExtension.Nullability.DISCRETE, + List.of("INTEGER", "INTEGER"))); + } + + /** + * Parses the operator's signature string and checks that the types match the expected list + * index-by-index. 
+ */ + private void verifyAllowedSignatures( + final SqlOperator operator, final List expectedArgTypes) { + assertNotNull(operator.getOperandTypeChecker(), "Operand type checker is null"); + + // e.g., "SAFE_DIVIDE_CUSTOM(, )" + final String signature = + operator + .getOperandTypeChecker() + .getAllowedSignatures(operator, operator.getName()) + .toUpperCase(); + + // Regex to capture arguments inside parentheses: NAME(ARG1, ARG2) + final Pattern pattern = Pattern.compile(".*?\\((.*)\\).*"); + final Matcher matcher = pattern.matcher(signature); + + assertTrue(matcher.matches(), () -> "Signature format not recognized: " + signature); + + // Split args by comma (assuming simple types for this test suite) + final String argsPart = matcher.group(1); + final List actualArgTypes = + Arrays.stream(argsPart.split(",")).map(String::trim).toList(); + + assertEquals( + expectedArgTypes.size(), + actualArgTypes.size(), + () -> "Signature argument count mismatch. Signature: " + signature); + + // Positional Check + for (int i = 0; i < expectedArgTypes.size(); i++) { + final String expected = expectedArgTypes.get(i); + final String actual = actualArgTypes.get(i); + + final SqlTypeName sqlTypeName = SqlTypeName.valueOf(expected); + final String familyName = sqlTypeName.getFamily().toString(); + + // Check if the actual slot matches the specific type OR the generic family + // e.g. Expected "INTEGER" matches actual "" or "INTEGER" + final boolean match = actual.contains(expected) || actual.contains(familyName); + + final int index = i; + assertTrue( + match, + () -> + "Argument mismatch at index " + + index + + ".\n" + + "Expected: " + + expected + + " (Family: " + + familyName + + ")\n" + + "Actual: " + + actual + + "\n" + + "Full Signature: " + + signature); + } + } + + private void verifyReturnTypeConsistency( + final SqlOperator operator, final SimpleExtension.Function funcDef) { + assertNotNull(operator.getReturnTypeInference(), "Return type inference is null"); + + // A. 
Expected: Evaluate YAML return type -> Convert to Calcite + final Type yamlReturnType = + TypeExpressionEvaluator.evaluateExpression( + funcDef.returnType(), funcDef.args(), Collections.emptyList()); + final RelDataType expectedType = TypeConverter.DEFAULT.toCalcite(TYPE_FACTORY, yamlReturnType); + + // B. Actual: Infer from Operator (using empty binding, sufficient for static types) + final RelDataType actualType = + operator + .getReturnTypeInference() + .inferReturnType(createMockBinding(operator, Collections.emptyList())); + + // C. Compare + assertEquals( + expectedType.getSqlTypeName(), + actualType.getSqlTypeName(), + () -> "Return type mismatch for " + funcDef.name()); + assertEquals( + expectedType.isNullable(), + actualType.isNullable(), + () -> "Nullability mismatch for " + funcDef.name()); + } + + private static SqlOperator getOperator(final String name) { + final SqlOperator op = OPERATORS.get(name.toLowerCase()); + assertNotNull(op, "Operator not found: " + name); + return op; + } + + private static SimpleExtension.Function getFunctionDef(final String name) { + final SimpleExtension.Function func = FUNCTION_DEFS.get(name.toLowerCase()); + assertNotNull(func, "YAML Def not found: " + name); + return func; + } + + private SqlOperatorBinding createMockBinding( + final SqlOperator operator, final List argumentTypes) { + return new SqlOperatorBinding(TYPE_FACTORY, operator) { + @Override + public int getOperandCount() { + return argumentTypes.size(); + } + + @Override + public RelDataType getOperandType(final int ordinal) { + return argumentTypes.get(ordinal); + } + + @Override + public CalciteException newError(final Resources.ExInst e) { + return new CalciteException(e.toString(), null); + } + }; + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java new file mode 100644 index 000000000..69b8be3b9 --- /dev/null +++ 
b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java @@ -0,0 +1,47 @@ +package io.substrait.isthmus; + +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import java.util.List; +import org.apache.calcite.prepare.Prepare; +import org.junit.jupiter.api.Test; + +class UdfSqlSubstraitTest extends PlanTestBase { + + private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml"; + + UdfSqlSubstraitTest() { + super(loadExtensions(List.of(CUSTOM_FUNCTION_PATH))); + } + + @Test + void customUdfTest() throws Exception { + + final Prepare.CatalogReader catalogReader = + SubstraitCreateStatementParser.processCreateStatementsToCatalog( + "CREATE TABLE t(x VARCHAR NOT NULL)"); + + FeatureBoard featureBoard = ImmutableFeatureBoard.builder().allowDynamicUdfs(true).build(); + + assertSqlSubstraitRelRoundTripLoosePojoComparison( + "SELECT regexp_extract_custom(x, 'ab') from t", catalogReader, featureBoard); + assertSqlSubstraitRelRoundTripLoosePojoComparison( + "SELECT format_text('UPPER', x) FROM t", catalogReader, featureBoard); + assertSqlSubstraitRelRoundTripLoosePojoComparison( + "SELECT system_property_get(x) FROM t", catalogReader, featureBoard); + assertSqlSubstraitRelRoundTripLoosePojoComparison( + "SELECT safe_divide_custom(10,0) FROM t", catalogReader, featureBoard); + } + + private static SimpleExtension.ExtensionCollection loadExtensions( + List yamlFunctionFiles) { + SimpleExtension.ExtensionCollection extensions = DefaultExtensionCatalog.DEFAULT_COLLECTION; + if (yamlFunctionFiles != null && !yamlFunctionFiles.isEmpty()) { + SimpleExtension.ExtensionCollection customExtensions = + SimpleExtension.load(yamlFunctionFiles); + extensions = extensions.merge(customExtensions); + } + return extensions; + } +} diff --git a/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml 
b/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml new file mode 100644 index 000000000..05595eb2f --- /dev/null +++ b/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml @@ -0,0 +1,43 @@ +%YAML 1.2 +--- +urn: extension:substrait:functions_custom +scalar_functions: + - name: "regexp_extract_custom" + impls: + - args: + - name: "text" + value: string + - name: "pattern" + value: string + return: string + + - name: "format_text" + description: "Formats text based on a mode. The output is nullable if the input is." + impls: + - args: + - name: "mode" + value: string + - name: "input_text" + value: string + return: string + nullability: MIRROR + + - name: "system_property_get" + description: "Safely gets a system property. Always returns a nullable string." + impls: + - args: + - name: "property_name" + value: string + return: string? + nullability: DECLARED_OUTPUT + + - name: "safe_divide_custom" + description: "Performs division, returning NULL if the denominator is zero." + impls: + - args: + - name: "numerator" + value: i32 + - name: "denominator" + value: i32 + return: fp32? + nullability: DISCRETE