diff --git a/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java
new file mode 100644
index 000000000..377020bb3
--- /dev/null
+++ b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java
@@ -0,0 +1,49 @@
+package io.substrait.isthmus;
+
+import io.substrait.extension.SimpleExtension;
+import io.substrait.isthmus.expression.FunctionMappings;
+import java.util.List;
+import java.util.Locale;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class ExtensionUtils {
+
+  /**
+   * Extracts dynamic extensions from a collection of extensions.
+   *
+   * <p>A dynamic extension is a user-defined function (UDF) that is not part of the standard
+   * Substrait function catalog. These are custom functions that users define and provide at
+   * runtime, extending the built-in function set with domain-specific or application-specific
+   * operations.
+   *
+   * <p>This method drops every scalar function whose lower-cased name appears in {@link
+   * FunctionMappings#SCALAR_SIGS} (the built-in mappings) and keeps only the remaining custom
+   * scalar functions. Aggregate and other function kinds are not carried over yet.
+   *
+   * <p>Example: If a user defines a custom UDF "my_hash_function" that computes a
+   * proprietary hash, this would be a dynamic extension since it's not part of the standard
+   * Substrait specification.
+   *
+   * @param extensions the complete collection of extensions (both standard and custom)
+   * @return a new ExtensionCollection containing only the dynamic (custom/user-defined) functions
+   *     that are not present in the standard Substrait function catalog
+   */
+  public static SimpleExtension.ExtensionCollection getDynamicExtensions(
+      SimpleExtension.ExtensionCollection extensions) {
+    Set<String> knownFunctionNames =
+        FunctionMappings.SCALAR_SIGS.stream()
+            .map(FunctionMappings.Sig::name)
+            .collect(Collectors.toSet());
+
+    List<SimpleExtension.ScalarFunctionVariant> customFunctions =
+        extensions.scalarFunctions().stream()
+            .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase(Locale.ROOT)))
+            .collect(Collectors.toList());
+
+    return SimpleExtension.ExtensionCollection.builder()
+        .scalarFunctions(customFunctions)
+        // TODO: handle aggregates and other functions
+        .build();
+  }
+}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java b/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java
index 8db29f73c..a54f24146 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java
@@ -17,4 +17,19 @@ public abstract class FeatureBoard {
public Casing unquotedCasing() {
return Casing.TO_UPPER;
}
+
+ /**
+ * Controls whether to support dynamic user-defined functions (UDFs) during SQL to Substrait plan
+ * conversion.
+ *
+ * <p>When enabled, custom functions defined in extension YAML files are available for use in SQL
+ * queries. These functions will be dynamically converted to SQL operators during plan conversion.
+ * This feature must be explicitly enabled by users and is disabled by default.
+ *
+ * @return {@code true} if dynamic UDFs should be supported; {@code false} otherwise (default)
+ */
+ @Value.Default
+ public boolean allowDynamicUdfs() {
+ return false;
+ }
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java b/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java
new file mode 100644
index 000000000..3c61acd94
--- /dev/null
+++ b/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java
@@ -0,0 +1,342 @@
+package io.substrait.isthmus;
+
+import io.substrait.extension.SimpleExtension;
+import io.substrait.function.ParameterizedType;
+import io.substrait.function.ParameterizedTypeVisitor;
+import io.substrait.function.TypeExpression;
+import io.substrait.type.Type;
+import io.substrait.type.TypeExpressionEvaluator;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.sql.SqlFunction;
+import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.SqlOperatorBinding;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.type.SqlTypeFamily;
+import org.apache.calcite.sql.type.SqlTypeName;
+
+public final class SimpleExtensionToSqlOperator {
+
+ private static final RelDataTypeFactory DEFAULT_TYPE_FACTORY =
+ new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
+
+ private static final CalciteTypeVisitor CALCITE_TYPE_VISITOR = new CalciteTypeVisitor();
+
+ private SimpleExtensionToSqlOperator() {}
+
+ public static List from(SimpleExtension.ExtensionCollection collection) {
+ return from(collection, DEFAULT_TYPE_FACTORY);
+ }
+
+ public static List from(
+ SimpleExtension.ExtensionCollection collection, RelDataTypeFactory typeFactory) {
+ return from(collection, typeFactory, TypeConverter.DEFAULT);
+ }
+
+ public static List from(
+ SimpleExtension.ExtensionCollection collection,
+ RelDataTypeFactory typeFactory,
+ TypeConverter typeConverter) {
+ // TODO: add support for windows functions
+ return Stream.concat(
+ collection.scalarFunctions().stream(), collection.aggregateFunctions().stream())
+ .map(function -> toSqlFunction(function, typeFactory, typeConverter))
+ .collect(Collectors.toList());
+ }
+
+ private static SqlFunction toSqlFunction(
+ SimpleExtension.Function function,
+ RelDataTypeFactory typeFactory,
+ TypeConverter typeConverter) {
+
+ List argFamilies = new ArrayList<>();
+
+ for (SimpleExtension.Argument arg : function.requiredArguments()) {
+ if (arg instanceof SimpleExtension.ValueArgument) {
+ SimpleExtension.ValueArgument valueArg = (SimpleExtension.ValueArgument) arg;
+ SqlTypeName typeName = valueArg.value().accept(CALCITE_TYPE_VISITOR);
+ argFamilies.add(typeName.getFamily());
+ } else if (arg instanceof SimpleExtension.EnumArgument) {
+ // Treat an EnumArgument as a required string literal.
+ argFamilies.add(SqlTypeFamily.STRING);
+ }
+ }
+
+ SqlReturnTypeInference returnTypeInference =
+ new SubstraitReturnTypeInference(function, typeFactory, typeConverter);
+
+ return new SqlFunction(
+ function.name(),
+ SqlKind.OTHER_FUNCTION,
+ returnTypeInference,
+ null,
+ OperandTypes.family(argFamilies),
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+ }
+
+ private static class SubstraitReturnTypeInference implements SqlReturnTypeInference {
+
+ private final SimpleExtension.Function function;
+ private final RelDataTypeFactory typeFactory;
+ private final TypeConverter typeConverter;
+
+ private SubstraitReturnTypeInference(
+ SimpleExtension.Function function,
+ RelDataTypeFactory typeFactory,
+ TypeConverter typeConverter) {
+ this.function = function;
+ this.typeFactory = typeFactory;
+ this.typeConverter = typeConverter;
+ }
+
+ @Override
+ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
+ List substraitArgTypes =
+ opBinding.collectOperandTypes().stream()
+ .map(typeConverter::toSubstrait)
+ .collect(Collectors.toList());
+
+ TypeExpression returnExpression = function.returnType();
+ Type resolvedSubstraitType =
+ TypeExpressionEvaluator.evaluateExpression(
+ returnExpression, function.args(), substraitArgTypes);
+
+ boolean finalIsNullable;
+ switch (function.nullability()) {
+ case MIRROR:
+ // If any input is nullable, the output is nullable.
+ finalIsNullable =
+ opBinding.collectOperandTypes().stream().anyMatch(RelDataType::isNullable);
+ break;
+ case DISCRETE:
+ case DECLARED_OUTPUT:
+ default:
+ // Use the nullability declared on the resolved Substrait type.
+ finalIsNullable = resolvedSubstraitType.nullable();
+ break;
+ }
+
+ RelDataType baseCalciteType = typeConverter.toCalcite(typeFactory, resolvedSubstraitType);
+
+ return typeFactory.createTypeWithNullability(baseCalciteType, finalIsNullable);
+ }
+ }
+
+ private static class CalciteTypeVisitor
+ extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor<
+ SqlTypeName, RuntimeException> {
+
+ private CalciteTypeVisitor() {
+ super("Type not supported for Calcite conversion.");
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Bool expr) {
+ return SqlTypeName.BOOLEAN;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.I8 expr) {
+ return SqlTypeName.TINYINT;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.I16 expr) {
+ return SqlTypeName.SMALLINT;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.I32 expr) {
+ return SqlTypeName.INTEGER;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.I64 expr) {
+ return SqlTypeName.BIGINT;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.FP32 expr) {
+ return SqlTypeName.FLOAT;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.FP64 expr) {
+ return SqlTypeName.DOUBLE;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Str expr) {
+ return SqlTypeName.VARCHAR;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Binary expr) {
+ return SqlTypeName.VARBINARY;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Date expr) {
+ return SqlTypeName.DATE;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Time expr) {
+ return SqlTypeName.TIME;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.TimestampTZ expr) {
+ return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Timestamp expr) {
+ return SqlTypeName.TIMESTAMP;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.IntervalYear year) {
+ return SqlTypeName.INTERVAL_YEAR_MONTH;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.IntervalDay day) {
+ return SqlTypeName.INTERVAL_DAY;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.UUID expr) {
+ return SqlTypeName.UUID;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Struct struct) {
+ return SqlTypeName.ROW;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.ListType listType) {
+ return SqlTypeName.ARRAY;
+ }
+
+ @Override
+ public SqlTypeName visit(Type.Map map) {
+ return SqlTypeName.MAP;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.FixedChar expr) {
+ return SqlTypeName.CHAR;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.VarChar expr) {
+ return SqlTypeName.VARCHAR;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.FixedBinary expr) {
+ return SqlTypeName.BINARY;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.Decimal expr) {
+ return SqlTypeName.DECIMAL;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.Struct expr) {
+ return SqlTypeName.ROW;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.ListType expr) {
+ return SqlTypeName.ARRAY;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.Map expr) {
+ return SqlTypeName.MAP;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.PrecisionTimestamp expr) {
+ return SqlTypeName.TIMESTAMP;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.PrecisionTimestampTZ expr) {
+ return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.PrecisionTime expr) {
+ return SqlTypeName.TIME;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.IntervalDay expr) {
+ return SqlTypeName.INTERVAL_DAY;
+ }
+
+ @Override
+ public SqlTypeName visit(ParameterizedType.StringLiteral expr) {
+ String type = expr.value().toUpperCase();
+
+ if (type.startsWith("ANY")) {
+ return SqlTypeName.ANY;
+ }
+
+ switch (type) {
+ case "BOOLEAN":
+ return SqlTypeName.BOOLEAN;
+ case "I8":
+ return SqlTypeName.TINYINT;
+ case "I16":
+ return SqlTypeName.SMALLINT;
+ case "I32":
+ return SqlTypeName.INTEGER;
+ case "I64":
+ return SqlTypeName.BIGINT;
+ case "FP32":
+ return SqlTypeName.FLOAT;
+ case "FP64":
+ return SqlTypeName.DOUBLE;
+ case "STRING":
+ return SqlTypeName.VARCHAR;
+ case "BINARY":
+ return SqlTypeName.VARBINARY;
+ case "TIMESTAMP":
+ return SqlTypeName.TIMESTAMP;
+ case "TIMESTAMP_TZ":
+ return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
+ case "DATE":
+ return SqlTypeName.DATE;
+ case "TIME":
+ return SqlTypeName.TIME;
+ case "UUID":
+ return SqlTypeName.UUID;
+ default:
+ if (type.startsWith("DECIMAL")) {
+ return SqlTypeName.DECIMAL;
+ }
+ if (type.startsWith("STRUCT")) {
+ return SqlTypeName.ROW;
+ }
+ if (type.startsWith("LIST")) {
+ return SqlTypeName.ARRAY;
+ }
+ return super.visit(expr);
+ }
+ }
+ }
+}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java
index 671deabe5..f667deab0 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java
@@ -19,8 +19,7 @@
import org.apache.calcite.sql2rel.SqlToRelConverter;
public class SqlConverterBase {
- protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
- DefaultExtensionCatalog.DEFAULT_COLLECTION;
+ protected final SimpleExtension.ExtensionCollection extensionCollection;
public static final CalciteConnectionConfig CONNECTION_CONFIG =
CalciteConnectionConfig.DEFAULT.set(
@@ -36,7 +35,8 @@ public class SqlConverterBase {
protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
final FeatureBoard featureBoard;
- protected SqlConverterBase(FeatureBoard features) {
+ protected SqlConverterBase(
+ FeatureBoard features, SimpleExtension.ExtensionCollection extensionCollection) {
this.factory = SubstraitTypeSystem.TYPE_FACTORY;
this.config =
CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false");
@@ -55,5 +55,11 @@ protected SqlConverterBase(FeatureBoard features) {
.withUnquotedCasing(featureBoard.unquotedCasing())
.withParserFactory(SqlDdlParserImpl.FACTORY)
.withConformance(SqlConformanceEnum.LENIENT);
+
+ this.extensionCollection = extensionCollection;
+ }
+
+ /**
+ * Creates a converter base that uses {@code DefaultExtensionCatalog.DEFAULT_COLLECTION} as its
+ * extension collection.
+ *
+ * @param features the feature board, forwarded unchanged to the two-argument constructor
+ */
+ protected SqlConverterBase(FeatureBoard features) {
+ this(features, DefaultExtensionCatalog.DEFAULT_COLLECTION);
}
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java
index c32fab07c..3d45f8bde 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java
@@ -3,6 +3,7 @@
import io.substrait.extendedexpression.ExtendedExpression;
import io.substrait.extendedexpression.ExtendedExpressionProtoConverter;
import io.substrait.extendedexpression.ImmutableExtendedExpression.Builder;
+import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.calcite.SubstraitTable;
import io.substrait.isthmus.expression.RexExpressionConverter;
@@ -34,12 +35,12 @@ public class SqlExpressionToSubstrait extends SqlConverterBase {
protected final RexExpressionConverter rexConverter;
+ /**
+ * Creates a converter with {@code FEATURES_DEFAULT} and {@code
+ * DefaultExtensionCatalog.DEFAULT_COLLECTION}.
+ */
public SqlExpressionToSubstrait() {
- this(FEATURES_DEFAULT, EXTENSION_COLLECTION);
+ this(FEATURES_DEFAULT, DefaultExtensionCatalog.DEFAULT_COLLECTION);
}
public SqlExpressionToSubstrait(
FeatureBoard features, SimpleExtension.ExtensionCollection extensions) {
- super(features);
+ super(features, extensions);
ScalarFunctionConverter scalarFunctionConverter =
new ScalarFunctionConverter(extensions.scalarFunctions(), factory);
this.rexConverter = new RexExpressionConverter(scalarFunctionConverter);
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
index 3e19ca58c..abe935a75 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
@@ -1,22 +1,49 @@
package io.substrait.isthmus;
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.extension.SimpleExtension;
+import io.substrait.isthmus.calcite.SubstraitOperatorTable;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.plan.ImmutablePlan.Builder;
import io.substrait.plan.Plan;
import io.substrait.plan.Plan.Version;
import io.substrait.plan.PlanProtoConverter;
+import java.util.List;
import org.apache.calcite.prepare.Prepare;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
+import org.apache.calcite.sql.util.SqlOperatorTables;
/** Take a SQL statement and a set of table definitions and return a substrait plan. */
public class SqlToSubstrait extends SqlConverterBase {
+ private final SqlOperatorTable operatorTable;
+ /**
+ * Creates a converter over {@code DefaultExtensionCatalog.DEFAULT_COLLECTION} and passes {@code
+ * null} as the feature board.
+ */
+ // NOTE(review): presumably the base constructor substitutes a default FeatureBoard for null;
+ // the two-argument constructor dereferences featureBoard - confirm.
public SqlToSubstrait() {
- this(null);
+ this(DefaultExtensionCatalog.DEFAULT_COLLECTION, null);
}
+ /**
+ * Creates a converter over {@code DefaultExtensionCatalog.DEFAULT_COLLECTION} with the given
+ * feature board.
+ *
+ * @param features the feature board, forwarded to the two-argument constructor
+ */
public SqlToSubstrait(FeatureBoard features) {
- super(features);
+ this(DefaultExtensionCatalog.DEFAULT_COLLECTION, features);
+ }
+
+  /**
+   * Creates a converter for the given extensions and feature board.
+   *
+   * <p>The operator table used for SQL validation is {@link SubstraitOperatorTable#INSTANCE}. When
+   * the feature board allows dynamic UDFs and {@link ExtensionUtils#getDynamicExtensions} yields at
+   * least one scalar or aggregate function, operators generated for those functions are chained
+   * after the built-in table.
+   *
+   * @param extensions the extension collection available to the converter
+   * @param features the feature board, forwarded to the base constructor
+   */
+  public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
+    super(features, extensions);
+
+    // Built-in operators are the baseline; the table is only extended further down.
+    SqlOperatorTable resolvedTable = SubstraitOperatorTable.INSTANCE;
+
+    if (featureBoard.allowDynamicUdfs()) {
+      SimpleExtension.ExtensionCollection dynamicExtensions =
+          ExtensionUtils.getDynamicExtensions(extensions);
+      boolean nothingDynamic =
+          dynamicExtensions.scalarFunctions().isEmpty()
+              && dynamicExtensions.aggregateFunctions().isEmpty();
+      if (!nothingDynamic) {
+        List<SqlOperator> dynamicOperators =
+            SimpleExtensionToSqlOperator.from(dynamicExtensions, this.factory);
+        resolvedTable =
+            SqlOperatorTables.chain(
+                SubstraitOperatorTable.INSTANCE, SqlOperatorTables.of(dynamicOperators));
+      }
+    }
+
+    this.operatorTable = resolvedTable;
   }
/**
@@ -53,8 +80,8 @@ public Plan convert(String sqlStatements, Prepare.CatalogReader catalogReader)
builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build());
// TODO: consider case in which one sql passes conversion while others don't
- SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader).stream()
- .map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
+ SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader, operatorTable).stream()
+ .map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard))
.forEach(root -> builder.addRoots(root));
return builder.build();
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
index 96e091578..47daf97e2 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
@@ -1,7 +1,5 @@
package io.substrait.isthmus;
-import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
-
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
@@ -14,6 +12,7 @@
import io.substrait.isthmus.calcite.rel.CreateView;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.ExpressionRexConverter;
+import io.substrait.isthmus.expression.FunctionMappings;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.AbstractDdlRel;
@@ -110,10 +109,18 @@ public SubstraitRelNodeConverter(
SimpleExtension.ExtensionCollection extensions,
RelDataTypeFactory typeFactory,
RelBuilder relBuilder) {
+ this(extensions, typeFactory, relBuilder, ImmutableFeatureBoard.builder().build());
+ }
+
+ public SubstraitRelNodeConverter(
+ SimpleExtension.ExtensionCollection extensions,
+ RelDataTypeFactory typeFactory,
+ RelBuilder relBuilder,
+ FeatureBoard featureBoard) {
this(
typeFactory,
relBuilder,
- new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
+ createScalarFunctionConverter(extensions, typeFactory, featureBoard.allowDynamicUdfs()),
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory),
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
TypeConverter.DEFAULT);
@@ -155,11 +162,68 @@ public SubstraitRelNodeConverter(
this.expressionRexConverter.setRelNodeConverter(this);
}
+  /**
+   * Builds the {@link ScalarFunctionConverter} for the given extensions.
+   *
+   * <p>When {@code allowDynamicUdfs} is set, the scalar functions selected by {@link
+   * ExtensionUtils#getDynamicExtensions} are turned into Calcite operators and passed to the
+   * converter as additional signatures, each keyed by the operator's own name. Otherwise, or when
+   * no such function exists, the additional signature list is empty.
+   *
+   * @param extensions the extension collection whose scalar functions are converted
+   * @param typeFactory the type factory used for operator generation and by the converter
+   * @param allowDynamicUdfs whether dynamic UDFs should be registered
+   * @return the scalar function converter
+   */
+  private static ScalarFunctionConverter createScalarFunctionConverter(
+      SimpleExtension.ExtensionCollection extensions,
+      RelDataTypeFactory typeFactory,
+      boolean allowDynamicUdfs) {
+
+    List<FunctionMappings.Sig> additionalSignatures = Collections.emptyList();
+
+    if (allowDynamicUdfs) {
+      // Use the shared helper so the "dynamic" filter (including its Locale.ROOT lower-casing)
+      // is the same one SqlToSubstrait and SubstraitRelVisitor apply.
+      SimpleExtension.ExtensionCollection dynamicExtensionCollection =
+          ExtensionUtils.getDynamicExtensions(extensions);
+
+      if (!dynamicExtensionCollection.scalarFunctions().isEmpty()) {
+        additionalSignatures =
+            SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory).stream()
+                .map(op -> FunctionMappings.s(op, op.getName()))
+                .collect(Collectors.toList());
+      }
+    }
+
+    return new ScalarFunctionConverter(
+        extensions.scalarFunctions(), additionalSignatures, typeFactory, TypeConverter.DEFAULT);
+  }
+
+
+ /**
+ * Converts a Substrait relation to a Calcite {@link RelNode} using the given extensions and a
+ * default-built {@code ImmutableFeatureBoard}; delegates to the overload that also takes a
+ * {@link FeatureBoard}.
+ */
public static RelNode convert(
Rel relRoot,
RelOptCluster relOptCluster,
Prepare.CatalogReader catalogReader,
- SqlParser.Config parserConfig) {
+ SqlParser.Config parserConfig,
+ SimpleExtension.ExtensionCollection extensions) {
+ return convert(
+ relRoot,
+ relOptCluster,
+ catalogReader,
+ parserConfig,
+ extensions,
+ ImmutableFeatureBoard.builder().build());
+ }
+
+ public static RelNode convert(
+ Rel relRoot,
+ RelOptCluster relOptCluster,
+ Prepare.CatalogReader catalogReader,
+ SqlParser.Config parserConfig,
+ SimpleExtension.ExtensionCollection extensions,
+ FeatureBoard featureBoard) {
RelBuilder relBuilder =
RelBuilder.create(
Frameworks.newConfigBuilder()
@@ -171,7 +235,7 @@ public static RelNode convert(
return relRoot.accept(
new SubstraitRelNodeConverter(
- EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), relBuilder),
+ extensions, relOptCluster.getTypeFactory(), relBuilder, featureBoard),
Context.newContext());
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
index 3dde72c13..5c3bb7a72 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
@@ -9,6 +9,7 @@
import io.substrait.isthmus.calcite.rel.CreateView;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.CallConverters;
+import io.substrait.isthmus.expression.FunctionMappings;
import io.substrait.isthmus.expression.LiteralConverter;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
@@ -62,6 +63,7 @@
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;
@@ -88,10 +90,31 @@ public SubstraitRelVisitor(
RelDataTypeFactory typeFactory,
SimpleExtension.ExtensionCollection extensions,
FeatureBoard features) {
+
this.typeConverter = TypeConverter.DEFAULT;
- ArrayList converters = new ArrayList();
+ ArrayList converters = new ArrayList<>();
converters.addAll(CallConverters.defaults(typeConverter));
- converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory));
+
+ if (features.allowDynamicUdfs()) {
+ SimpleExtension.ExtensionCollection dynamicExtensionCollection =
+ ExtensionUtils.getDynamicExtensions(extensions);
+ List dynamicOperators =
+ SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory);
+
+ List additionalSignatures =
+ dynamicOperators.stream()
+ .map(op -> FunctionMappings.s(op, op.getName()))
+ .collect(Collectors.toList());
+ converters.add(
+ new ScalarFunctionConverter(
+ extensions.scalarFunctions(),
+ additionalSignatures,
+ typeFactory,
+ TypeConverter.DEFAULT));
+ } else {
+ converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory));
+ }
+
converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory)));
this.aggregateFunctionConverter =
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory);
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
index 8dcfbf9e0..772a3e192 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
@@ -38,6 +38,7 @@ public class SubstraitToCalcite {
protected final RelDataTypeFactory typeFactory;
protected final TypeConverter typeConverter;
protected final Prepare.CatalogReader catalogReader;
+ protected final FeatureBoard featureBoard;
public SubstraitToCalcite(
SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) {
@@ -63,10 +64,25 @@ public SubstraitToCalcite(
RelDataTypeFactory typeFactory,
TypeConverter typeConverter,
Prepare.CatalogReader catalogReader) {
+ this(
+ extensions,
+ typeFactory,
+ typeConverter,
+ catalogReader,
+ ImmutableFeatureBoard.builder().build());
+ }
+
+ /**
+ * Creates a converter that stores the given collaborators as-is.
+ *
+ * @param extensions the extension collection used for function lookup
+ * @param typeFactory the Calcite type factory
+ * @param typeConverter the converter between Substrait and Calcite types
+ * @param catalogReader the catalog reader stored for later use
+ * @param featureBoard the feature board handed to the {@link SubstraitRelNodeConverter} created
+ *     by {@code createSubstraitRelNodeConverter}
+ */
+ public SubstraitToCalcite(
+ SimpleExtension.ExtensionCollection extensions,
+ RelDataTypeFactory typeFactory,
+ TypeConverter typeConverter,
+ Prepare.CatalogReader catalogReader,
+ FeatureBoard featureBoard) {
this.extensions = extensions;
this.typeFactory = typeFactory;
this.typeConverter = typeConverter;
this.catalogReader = catalogReader;
+ this.featureBoard = featureBoard;
}
/**
@@ -94,7 +110,7 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) {
* Override this method to customize the {@link SubstraitRelNodeConverter}.
*/
protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) {
- return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder);
+ return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder, featureBoard);
}
/**
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java
index 421b45317..e327ab007 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java
@@ -1,5 +1,6 @@
package io.substrait.isthmus;
+import io.substrait.extension.SimpleExtension;
import io.substrait.relation.Rel;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelNode;
@@ -10,7 +11,12 @@ public SubstraitToSql() {
super(FEATURES_DEFAULT);
}
+ /**
+ * Creates a converter that uses {@code FEATURES_DEFAULT} together with the given extensions.
+ *
+ * @param extensions the extension collection forwarded to the base constructor
+ */
+ public SubstraitToSql(SimpleExtension.ExtensionCollection extensions) {
+ super(FEATURES_DEFAULT, extensions);
+ }
+
+ /**
+ * Converts a Substrait relation to a Calcite {@link RelNode} via {@link
+ * SubstraitRelNodeConverter}, passing this converter's cluster, parser configuration and
+ * extension collection.
+ *
+ * @param relRoot the Substrait relation to convert
+ * @param catalog the catalog reader forwarded to the conversion
+ * @return the converted Calcite relation
+ */
public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) {
- return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, catalog, parserConfig);
+ return SubstraitRelNodeConverter.convert(
+ relRoot, relOptCluster, catalog, parserConfig, extensionCollection);
}
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
index d738ef157..b5604d4d9 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
@@ -25,7 +25,6 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -129,8 +128,7 @@ public FunctionConverter(
.collect(
Multimaps.toMultimap(
FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create));
- IdentityHashMap matcherMap =
- new IdentityHashMap();
+ Map matcherMap = new HashMap<>();
for (String key : nameToFn.keySet()) {
Collection sigs = calciteOperators.get(key);
if (sigs.isEmpty()) {
diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java
index a87e29563..b46d30e7c 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java
@@ -15,6 +15,7 @@
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.SqlToRelConverter;
@@ -41,6 +42,23 @@ public static RelRoot convertQuery(String sqlStatement, Prepare.CatalogReader ca
return convertQuery(sqlStatement, catalogReader, validator, createDefaultRelOptCluster());
}
+  /**
+   * Converts a single SQL statement to a Calcite {@link RelRoot}, validating it with a {@link
+   * SubstraitSqlValidator} built over the supplied operator table.
+   *
+   * @param sqlStatement a SQL statement string
+   * @param catalogReader the {@link Prepare.CatalogReader} used to resolve tables/views referenced
+   *     in the SQL statement
+   * @param operatorTable the {@link SqlOperatorTable} that determines which operators are valid
+   * @return the {@link RelRoot} for the given SQL statement
+   * @throws SqlParseException if the SQL statement cannot be parsed
+   */
+  public static RelRoot convertQuery(
+      String sqlStatement, Prepare.CatalogReader catalogReader, SqlOperatorTable operatorTable)
+      throws SqlParseException {
+    return convertQuery(
+        sqlStatement,
+        catalogReader,
+        new SubstraitSqlValidator(catalogReader, operatorTable),
+        createDefaultRelOptCluster());
+  }
+
+
/**
* Converts a SQL statement to a Calcite {@link RelRoot}.
*
@@ -72,6 +90,24 @@ public static RelRoot convertQuery(
return relRoots.get(0);
}
+  /**
+   * Converts one or more SQL statements to a List of {@link RelRoot}, with one {@link RelRoot} per
+   * statement. Validation uses a {@link SubstraitSqlValidator} built over the supplied operator
+   * table.
+   *
+   * @param sqlStatements a string containing one or more SQL statements
+   * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in
+   *     the SQL statements
+   * @param operatorTable the {@link SqlOperatorTable} for controlling valid operators
+   * @return a list of {@link RelRoot}s corresponding to the given SQL statements
+   * @throws SqlParseException if there is an error while parsing the SQL statements
+   */
+  public static List<RelRoot> convertQueries(
+      String sqlStatements, Prepare.CatalogReader catalogReader, SqlOperatorTable operatorTable)
+      throws SqlParseException {
+    SqlValidator validator = new SubstraitSqlValidator(catalogReader, operatorTable);
+    return convertQueries(sqlStatements, catalogReader, validator, createDefaultRelOptCluster());
+  }
+
+
/**
* Converts one or more SQL statements to a List of {@link RelRoot}, with one {@link RelRoot} per
* statement.
diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java
index 52be2d6a5..07b6edda8 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java
@@ -2,6 +2,7 @@
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
import org.apache.calcite.prepare.Prepare;
+import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorImpl;
@@ -12,4 +13,8 @@ public class SubstraitSqlValidator extends SqlValidatorImpl {
public SubstraitSqlValidator(Prepare.CatalogReader catalogReader) {
super(SubstraitOperatorTable.INSTANCE, catalogReader, catalogReader.getTypeFactory(), CONFIG);
}
+
+ /**
+  * Creates a validator that uses the given operator table rather than
+  * {@link SubstraitOperatorTable#INSTANCE}.
+  *
+  * @param catalogReader catalog reader handed to the parent validator; its type factory is used
+  *     as the validator's type factory
+  * @param opTable operator table handed to the parent validator
+  */
+ public SubstraitSqlValidator(Prepare.CatalogReader catalogReader, SqlOperatorTable opTable) {
+ super(opTable, catalogReader, catalogReader.getTypeFactory(), CONFIG);
+ }
}
diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java
index e21d3b653..94f77c6a9 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java
@@ -1,6 +1,5 @@
package io.substrait.isthmus;
-import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static io.substrait.isthmus.SubstraitTypeSystem.YEAR_MONTH_INTERVAL;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -11,6 +10,8 @@
import io.substrait.expression.Expression.Literal;
import io.substrait.expression.Expression.TimestampLiteral;
import io.substrait.expression.ExpressionCreator;
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.SubstraitRelNodeConverter.Context;
import io.substrait.isthmus.expression.ExpressionRexConverter;
import io.substrait.isthmus.expression.RexExpressionConverter;
@@ -35,6 +36,8 @@
import org.junit.jupiter.api.Test;
class CalciteLiteralTest extends CalciteObjs {
+ /** Class-local constant holding {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}. */
+ protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
+ DefaultExtensionCatalog.DEFAULT_COLLECTION;
private final ScalarFunctionConverter scalarFunctionConverter =
new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), type);
diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java
index dc3cfdbbc..7a0d4ddc5 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java
@@ -1,12 +1,13 @@
package io.substrait.isthmus;
-import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.ImmutableAggregateFunctionInvocation;
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.extension.SimpleExtension;
import io.substrait.relation.Aggregate;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Rel;
@@ -17,6 +18,8 @@
import org.junit.jupiter.api.Test;
class ComplexAggregateTest extends PlanTestBase {
+ /** Class-local constant holding {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}. */
+ protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
+ DefaultExtensionCatalog.DEFAULT_COLLECTION;
final TypeCreator R = TypeCreator.of(false);
SubstraitBuilder b = new SubstraitBuilder(extensions);
diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java
index afe088d30..300786f21 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java
@@ -1,10 +1,11 @@
package io.substrait.isthmus;
-import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.Expression;
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.extension.SimpleExtension;
import io.substrait.relation.Rel;
import io.substrait.type.TypeCreator;
import java.io.PrintWriter;
@@ -20,6 +21,9 @@
class ComplexSortTest extends PlanTestBase {
+ /** Class-local constant holding {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}. */
+ private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
+ DefaultExtensionCatalog.DEFAULT_COLLECTION;
+
final TypeCreator R = TypeCreator.of(false);
SubstraitBuilder b = new SubstraitBuilder(extensions);
diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java
index ed63f0c47..a99c5ff35 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java
@@ -1,8 +1,9 @@
package io.substrait.isthmus;
-import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.plan.Plan;
@@ -13,6 +14,9 @@
class NameRoundtripTest extends PlanTestBase {
+ /** Class-local constant holding {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}. */
+ private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
+ DefaultExtensionCatalog.DEFAULT_COLLECTION;
+
@Test
void preserveNamesFromSql() throws Exception {
String createStatement = "CREATE TABLE foo(a BIGINT, b BIGINT)";
diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java
index b408303a0..8f392aae2 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java
@@ -1,8 +1,9 @@
package io.substrait.isthmus;
-import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import java.io.IOException;
import org.apache.calcite.plan.hep.HepPlanner;
@@ -16,6 +17,9 @@
class OptimizerIntegrationTest extends PlanTestBase {
+ /** Class-local constant holding {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}. */
+ private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
+ DefaultExtensionCatalog.DEFAULT_COLLECTION;
+
@Test
void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException {
String query =
diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
index 7cf5c38c5..a37916bc4 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
@@ -8,6 +8,7 @@
import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import io.substrait.dsl.SubstraitBuilder;
+import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
@@ -38,12 +39,12 @@
import org.apache.calcite.tools.RelBuilder;
public class PlanTestBase {
- protected final SimpleExtension.ExtensionCollection extensions =
- SqlConverterBase.EXTENSION_COLLECTION;
+ protected final SimpleExtension.ExtensionCollection extensions;
+
protected final RelCreator creator = new RelCreator();
protected final RelBuilder builder = creator.createRelBuilder();
protected final RelDataTypeFactory typeFactory = creator.typeFactory();
- protected final SubstraitBuilder substraitBuilder = new SubstraitBuilder(extensions);
+ protected final SubstraitBuilder substraitBuilder;
protected static final TypeCreator R = TypeCreator.of(false);
protected static final TypeCreator N = TypeCreator.of(true);
@@ -63,6 +64,15 @@ public class PlanTestBase {
protected static CalciteCatalogReader TPCDS_CATALOG =
PlanTestBase.schemaToCatalog("tpcds", TPCDS_SCHEMA);
+ /** Creates a test base that uses {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}. */
+ protected PlanTestBase() {
+ this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
+ }
+
+ /**
+  * Creates a test base whose {@code extensions} field and {@link SubstraitBuilder} are both
+  * backed by the given extension collection.
+  *
+  * @param extensions the extension collection to store and to pass to the {@link
+  *     SubstraitBuilder}
+  */
+ protected PlanTestBase(SimpleExtension.ExtensionCollection extensions) {
+   this.substraitBuilder = new SubstraitBuilder(extensions);
+   this.extensions = extensions;
+ }
+
public static String asString(String resource) throws IOException {
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
}
@@ -142,6 +152,76 @@ protected RelRoot assertSqlSubstraitRelRoundTrip(
return relRoot2;
}
+ /**
+  * Checks that the given query survives repeated Substrait/Calcite round trips, using a loose
+  * POJO comparison.
+  *
+  * <p>"Loose" means that the POJO from the initial SQL→Substrait conversion is not compared with
+  * the POJO obtained after the first Substrait→Calcite→Substrait round trip, because the two may
+  * differ due to optimizer differences between:
+  *
+  * <ul>
+  *   <li>SqlNode→RelRoot conversion (SQL→Substrait path)
+  *   <li>RelBuilder/RexBuilder optimization (Substrait→Calcite path)
+  * </ul>
+  *
+  * <p>What is asserted instead is that the second and third POJOs are equal, i.e. that
+  * subsequent round trips produce stable results.
+  *
+  * @param query the SQL query to test
+  * @param catalogReader the Calcite catalog with table definitions
+  * @param featureBoard optional {@link FeatureBoard} controlling conversion behavior (e.g.
+  *     dynamic UDFs); when {@code null}, a default-built {@link FeatureBoard} is used
+  * @return the {@link RelRoot} produced by the first Substrait→Calcite conversion
+  * @throws Exception if any conversion step throws
+  */
+ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison(
+     String query, Prepare.CatalogReader catalogReader, FeatureBoard featureBoard)
+     throws Exception {
+   // A null FeatureBoard means "use the defaults".
+   final FeatureBoard effectiveFeatures;
+   if (featureBoard == null) {
+     effectiveFeatures = ImmutableFeatureBoard.builder().build();
+   } else {
+     effectiveFeatures = featureBoard;
+   }
+
+   SubstraitToCalcite toCalcite =
+       new SubstraitToCalcite(
+           extensions, typeFactory, TypeConverter.DEFAULT, null, effectiveFeatures);
+   SqlToSubstrait sqlConverter = new SqlToSubstrait(extensions, effectiveFeatures);
+
+   // SQL -> Substrait plan; the plan's first root is the initial POJO.
+   Plan initialPlan = sqlConverter.convert(query, catalogReader);
+   Plan.Root initialRoot = initialPlan.getRoots().get(0);
+
+   // First round trip: Substrait -> Calcite -> Substrait.
+   RelRoot firstCalciteRoot = toCalcite.convert(initialRoot);
+   Plan.Root secondRoot =
+       SubstraitRelVisitor.convert(firstCalciteRoot, extensions, effectiveFeatures);
+
+   // initialRoot and secondRoot are deliberately not compared: the SQL->Substrait path and the
+   // Substrait->Calcite path may apply different optimizations, so a difference is expected.
+
+   // Second round trip, starting from the result of the first one.
+   RelRoot secondCalciteRoot = toCalcite.convert(secondRoot);
+   Plan.Root thirdRoot =
+       SubstraitRelVisitor.convert(secondCalciteRoot, extensions, effectiveFeatures);
+
+   // From here on the conversion must be a fixed point.
+   assertEquals(secondRoot, thirdRoot);
+   return firstCalciteRoot;
+ }
+
+ /**
+  * Variant of {@link #assertSqlSubstraitRelRoundTripLoosePojoComparison(String,
+  * Prepare.CatalogReader, FeatureBoard)} that always runs with a default-built {@link
+  * FeatureBoard}.
+  */
+ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison(
+     String query, Prepare.CatalogReader catalogReader) throws Exception {
+   final FeatureBoard defaultFeatures = ImmutableFeatureBoard.builder().build();
+   return assertSqlSubstraitRelRoundTripLoosePojoComparison(query, catalogReader, defaultFeatures);
+ }
+
@Beta
protected void assertFullRoundTrip(String query) throws SqlParseException {
assertFullRoundTrip(query, TPCH_CATALOG);
diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java
new file mode 100644
index 000000000..030e501d8
--- /dev/null
+++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java
@@ -0,0 +1,232 @@
+package io.substrait.isthmus;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.substrait.extension.SimpleExtension;
+import io.substrait.type.Type;
+import io.substrait.type.TypeExpressionEvaluator;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.runtime.CalciteException;
+import org.apache.calcite.runtime.Resources;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.SqlOperatorBinding;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.validate.SqlValidatorException;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+/**
+ * Tests for conversion of SimpleExtension function definitions to Calcite SqlOperators.
+ *
+ * <p>The functions under test are loaded from {@code /extensions/scalar_functions_custom.yaml} on
+ * the test classpath. Name lookups and signature comparisons fold case with {@link Locale#ROOT}
+ * so the results do not depend on the JVM's default locale.
+ */
+class SimpleExtensionToSqlOperatorTest {
+
+  private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml";
+  private static final RelDataTypeFactory TYPE_FACTORY = SubstraitTypeSystem.TYPE_FACTORY;
+
+  /** Captures everything between the first "(" and the last ")" of a signature string. */
+  private static final Pattern SIGNATURE_ARGS = Pattern.compile(".*?\\((.*)\\).*");
+
+  /** Scalar function definitions from the YAML file, keyed by lower-cased function name. */
+  private static final Map<String, SimpleExtension.Function> FUNCTION_DEFS;
+
+  /** Operators converted from the YAML file, keyed by lower-cased operator name. */
+  private static final Map<String, SqlOperator> OPERATORS;
+
+  static {
+    final SimpleExtension.ExtensionCollection extensions =
+        SimpleExtension.load(
+            CUSTOM_FUNCTION_PATH,
+            SimpleExtensionToSqlOperatorTest.class.getResourceAsStream(CUSTOM_FUNCTION_PATH));
+
+    FUNCTION_DEFS =
+        extensions.scalarFunctions().stream()
+            .collect(
+                Collectors.toUnmodifiableMap(
+                    f -> f.name().toLowerCase(Locale.ROOT), Function.identity()));
+
+    OPERATORS =
+        SimpleExtensionToSqlOperator.from(extensions).stream()
+            .collect(
+                Collectors.toUnmodifiableMap(
+                    op -> op.getName().toLowerCase(Locale.ROOT), Function.identity()));
+  }
+
+  /**
+   * One parameterized test case.
+   *
+   * @param name function name to look up (matched case-insensitively)
+   * @param minArgs expected minimum operand count of the operator
+   * @param maxArgs expected maximum operand count of the operator
+   * @param nullability expected nullability of the function definition, or {@code null} to skip
+   *     that check
+   * @param expectedArgTypes expected {@link SqlTypeName} constant names, one per argument
+   *     position; an empty list skips the signature check
+   */
+  record TestSpec(
+      String name,
+      int minArgs,
+      int maxArgs,
+      SimpleExtension.Nullability nullability,
+      List<String> expectedArgTypes) {}
+
+  /**
+   * Checks operand count range, nullability, allowed signature and return type of the operator
+   * converted from the function named in {@code spec}.
+   */
+  @ParameterizedTest
+  @MethodSource("provideTestSpecs")
+  void testCustomUdfConversion(final TestSpec spec) {
+    final SqlOperator operator = getOperator(spec.name);
+    final SimpleExtension.Function funcDef = getFunctionDef(spec.name);
+
+    assertEquals(
+        spec.minArgs,
+        operator.getOperandCountRange().getMin(),
+        () -> spec.name + ": Incorrect min args");
+    assertEquals(
+        spec.maxArgs,
+        operator.getOperandCountRange().getMax(),
+        () -> spec.name + ": Incorrect max args");
+
+    if (spec.nullability != null) {
+      assertEquals(
+          spec.nullability, funcDef.nullability(), () -> spec.name + ": Incorrect nullability");
+    }
+
+    if (!spec.expectedArgTypes.isEmpty()) {
+      verifyAllowedSignatures(operator, spec.expectedArgTypes);
+    }
+
+    verifyReturnTypeConsistency(operator, funcDef);
+  }
+
+  /** Supplies the test cases for {@link #testCustomUdfConversion(TestSpec)}. */
+  private static Stream<TestSpec> provideTestSpecs() {
+    return Stream.of(
+        new TestSpec("REGEXP_EXTRACT_CUSTOM", 2, 2, null, List.of("VARCHAR", "VARCHAR")),
+        new TestSpec(
+            "FORMAT_TEXT", 2, 2, SimpleExtension.Nullability.MIRROR, List.of("VARCHAR", "VARCHAR")),
+        new TestSpec(
+            "SYSTEM_PROPERTY_GET",
+            1,
+            1,
+            SimpleExtension.Nullability.DECLARED_OUTPUT,
+            List.of("VARCHAR")),
+        new TestSpec(
+            "SAFE_DIVIDE_CUSTOM",
+            2,
+            2,
+            SimpleExtension.Nullability.DISCRETE,
+            List.of("INTEGER", "INTEGER")));
+  }
+
+  /**
+   * Parses the operator's signature string and checks that the types match the expected list
+   * index-by-index.
+   *
+   * @param operator the operator whose operand type checker is inspected
+   * @param expectedArgTypes expected {@link SqlTypeName} constant names, one per argument position
+   */
+  private void verifyAllowedSignatures(
+      final SqlOperator operator, final List<String> expectedArgTypes) {
+    assertNotNull(operator.getOperandTypeChecker(), "Operand type checker is null");
+
+    // The signature is expected to look like NAME(ARG1, ARG2); it is upper-cased with
+    // Locale.ROOT so the contains() checks below behave the same under every default locale.
+    final String signature =
+        operator
+            .getOperandTypeChecker()
+            .getAllowedSignatures(operator, operator.getName())
+            .toUpperCase(Locale.ROOT);
+
+    final Matcher matcher = SIGNATURE_ARGS.matcher(signature);
+
+    assertTrue(matcher.matches(), () -> "Signature format not recognized: " + signature);
+
+    // Split args by comma (assuming simple types for this test suite)
+    final String argsPart = matcher.group(1);
+    final List<String> actualArgTypes =
+        Arrays.stream(argsPart.split(",")).map(String::trim).toList();
+
+    assertEquals(
+        expectedArgTypes.size(),
+        actualArgTypes.size(),
+        () -> "Signature argument count mismatch. Signature: " + signature);
+
+    // Positional check
+    for (int i = 0; i < expectedArgTypes.size(); i++) {
+      final String expected = expectedArgTypes.get(i);
+      final String actual = actualArgTypes.get(i);
+
+      final SqlTypeName sqlTypeName = SqlTypeName.valueOf(expected);
+      final String familyName = sqlTypeName.getFamily().toString();
+
+      // A slot matches when it mentions either the specific type name or the name of that
+      // type's family.
+      final boolean match = actual.contains(expected) || actual.contains(familyName);
+
+      final int index = i;
+      assertTrue(
+          match,
+          () ->
+              "Argument mismatch at index "
+                  + index
+                  + ".\n"
+                  + "Expected: "
+                  + expected
+                  + " (Family: "
+                  + familyName
+                  + ")\n"
+                  + "Actual: "
+                  + actual
+                  + "\n"
+                  + "Full Signature: "
+                  + signature);
+    }
+  }
+
+  /**
+   * Checks that the operator's inferred return type agrees, in SQL type name and nullability,
+   * with the return type declared by the function definition.
+   */
+  private void verifyReturnTypeConsistency(
+      final SqlOperator operator, final SimpleExtension.Function funcDef) {
+    assertNotNull(operator.getReturnTypeInference(), "Return type inference is null");
+
+    // A. Expected: evaluate the declared return type, then convert it to a Calcite type.
+    final Type yamlReturnType =
+        TypeExpressionEvaluator.evaluateExpression(
+            funcDef.returnType(), funcDef.args(), Collections.emptyList());
+    final RelDataType expectedType = TypeConverter.DEFAULT.toCalcite(TYPE_FACTORY, yamlReturnType);
+
+    // B. Actual: infer from the operator, using a binding that reports no operands.
+    final RelDataType actualType =
+        operator
+            .getReturnTypeInference()
+            .inferReturnType(createMockBinding(operator, Collections.emptyList()));
+
+    // C. Compare
+    assertEquals(
+        expectedType.getSqlTypeName(),
+        actualType.getSqlTypeName(),
+        () -> "Return type mismatch for " + funcDef.name());
+    assertEquals(
+        expectedType.isNullable(),
+        actualType.isNullable(),
+        () -> "Nullability mismatch for " + funcDef.name());
+  }
+
+  /** Looks up a converted operator by name (case-insensitive); fails the test if absent. */
+  private static SqlOperator getOperator(final String name) {
+    final SqlOperator op = OPERATORS.get(name.toLowerCase(Locale.ROOT));
+    assertNotNull(op, "Operator not found: " + name);
+    return op;
+  }
+
+  /** Looks up a function definition by name (case-insensitive); fails the test if absent. */
+  private static SimpleExtension.Function getFunctionDef(final String name) {
+    final SimpleExtension.Function func = FUNCTION_DEFS.get(name.toLowerCase(Locale.ROOT));
+    assertNotNull(func, "YAML Def not found: " + name);
+    return func;
+  }
+
+  /**
+   * Builds a minimal {@link SqlOperatorBinding} that reports the given operand types.
+   *
+   * @param operator the operator the binding is created for
+   * @param argumentTypes operand types returned by ordinal
+   */
+  private SqlOperatorBinding createMockBinding(
+      final SqlOperator operator, final List<RelDataType> argumentTypes) {
+    return new SqlOperatorBinding(TYPE_FACTORY, operator) {
+      @Override
+      public int getOperandCount() {
+        return argumentTypes.size();
+      }
+
+      @Override
+      public RelDataType getOperandType(final int ordinal) {
+        return argumentTypes.get(ordinal);
+      }
+
+      @Override
+      public CalciteException newError(final Resources.ExInst<SqlValidatorException> e) {
+        return new CalciteException(e.toString(), null);
+      }
+    };
+  }
+}
diff --git a/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java
new file mode 100644
index 000000000..69b8be3b9
--- /dev/null
+++ b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java
@@ -0,0 +1,47 @@
+package io.substrait.isthmus;
+
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.extension.SimpleExtension;
+import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
+import java.util.List;
+import org.apache.calcite.prepare.Prepare;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Round-trip tests for SQL queries that call the custom scalar functions declared in {@code
+ * /extensions/scalar_functions_custom.yaml}.
+ */
+class UdfSqlSubstraitTest extends PlanTestBase {
+
+  private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml";
+
+  /** Runs the base class with the default extensions merged with the custom function file. */
+  UdfSqlSubstraitTest() {
+    super(loadExtensions(List.of(CUSTOM_FUNCTION_PATH)));
+  }
+
+  /**
+   * Round-trips one query per custom function, with dynamic UDFs enabled on the {@link
+   * FeatureBoard}.
+   */
+  @Test
+  void customUdfTest() throws Exception {
+
+    final Prepare.CatalogReader catalogReader =
+        SubstraitCreateStatementParser.processCreateStatementsToCatalog(
+            "CREATE TABLE t(x VARCHAR NOT NULL)");
+
+    FeatureBoard featureBoard = ImmutableFeatureBoard.builder().allowDynamicUdfs(true).build();
+
+    // Checked in this order; the first failing query aborts the test.
+    final List<String> queries =
+        List.of(
+            "SELECT regexp_extract_custom(x, 'ab') from t",
+            "SELECT format_text('UPPER', x) FROM t",
+            "SELECT system_property_get(x) FROM t",
+            "SELECT safe_divide_custom(10,0) FROM t");
+
+    for (String query : queries) {
+      assertSqlSubstraitRelRoundTripLoosePojoComparison(query, catalogReader, featureBoard);
+    }
+  }
+
+  /**
+   * Returns {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}, merged with the extensions loaded
+   * from the given files when any are supplied.
+   *
+   * @param yamlFunctionFiles extension file locations passed to {@code SimpleExtension.load};
+   *     may be {@code null} or empty, in which case only the default collection is returned
+   * @return the resulting extension collection
+   */
+  private static SimpleExtension.ExtensionCollection loadExtensions(
+      List<String> yamlFunctionFiles) {
+    final SimpleExtension.ExtensionCollection defaults =
+        DefaultExtensionCatalog.DEFAULT_COLLECTION;
+    if (yamlFunctionFiles == null || yamlFunctionFiles.isEmpty()) {
+      return defaults;
+    }
+    return defaults.merge(SimpleExtension.load(yamlFunctionFiles));
+  }
+}
diff --git a/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml b/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml
new file mode 100644
index 000000000..05595eb2f
--- /dev/null
+++ b/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml
@@ -0,0 +1,43 @@
+%YAML 1.2
+---
+# Custom scalar function definitions used as a test fixture; the tests load this file from the
+# classpath location /extensions/scalar_functions_custom.yaml.
+urn: extension:substrait:functions_custom
+scalar_functions:
+ # Two string arguments, string result; no nullability is declared for this implementation.
+ - name: "regexp_extract_custom"
+ impls:
+ - args:
+ - name: "text"
+ value: string
+ - name: "pattern"
+ value: string
+ return: string
+
+ - name: "format_text"
+ description: "Formats text based on a mode. The output is nullable if the input is."
+ impls:
+ - args:
+ - name: "mode"
+ value: string
+ - name: "input_text"
+ value: string
+ return: string
+ nullability: MIRROR
+
+ - name: "system_property_get"
+ description: "Safely gets a system property. Always returns a nullable string."
+ impls:
+ - args:
+ - name: "property_name"
+ value: string
+ return: string?
+ nullability: DECLARED_OUTPUT
+
+ - name: "safe_divide_custom"
+ description: "Performs division, returning NULL if the denominator is zero."
+ impls:
+ - args:
+ - name: "numerator"
+ value: i32
+ - name: "denominator"
+ value: i32
+ return: fp32?
+ nullability: DISCRETE