diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 072507295..bdff1b934 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -186,6 +186,11 @@ public O visit(Expression.MultiOrList expr, C context) throws E { return visitFallback(expr, context); } + @Override + public O visit(Expression.NestedList expr, C context) throws E { + return visitFallback(expr, context); + } + @Override public O visit(FieldReference expr, C context) throws E { return visitFallback(expr, context); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..548040b43 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -32,6 +32,13 @@ default boolean nullable() { } } + interface Nested extends Expression { + @Value.Default + default boolean nullable() { + return false; + } + } + R accept( ExpressionVisitor visitor, C context) throws E; @@ -922,6 +929,40 @@ public R accept( } } + /** + * A nested list expression with one or more elements. + * + *

Note: This class cannot be used to construct an empty list. To create an empty list, use + * {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link EmptyListLiteral}. + */ + @Value.Immutable + abstract class NestedList implements Nested { + public abstract List values(); + + @Value.Check + protected void check() { + assert !values().isEmpty() : "To specify an empty list, use ExpressionCreator.emptyList()"; + + assert values().stream().map(Expression::getType).distinct().count() <= 1 + : "All values in NestedList must have the same type"; + } + + @Override + public Type getType() { + return Type.withNullability(nullable()).list(values().get(0).getType()); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + + public static ImmutableExpression.NestedList.Builder builder() { + return ImmutableExpression.NestedList.builder(); + } + } + @Value.Immutable abstract class MultiOrListRecord { public abstract List values(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index adf157d7b..051bdc46e 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -281,6 +281,17 @@ public static Expression.StructLiteral struct(boolean nullable, Expression.Liter return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build(); } + /** + * Creates a nested list expression with one or more elements. + * + *

Note: This class cannot be used to construct an empty list. To create an empty list, use + * {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link + * Expression.EmptyListLiteral}. + */ + public static Expression.NestedList nestedList(boolean nullable, List values) { + return Expression.NestedList.builder().nullable(nullable).addAllValues(values).build(); + } + public static Expression.StructLiteral struct( boolean nullable, Iterable values) { return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index d64cab48c..8b8cb72fc 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -78,6 +78,8 @@ public interface ExpressionVisitor values = + expr.values().stream().map(this::toProto).collect(Collectors.toList()); + + return Expression.newBuilder() + .setNested( + Expression.Nested.newBuilder() + .setList(Expression.Nested.List.newBuilder().addAllValues(values)) + .setNullable(expr.nullable())) + .build(); + } + @Override public Expression visit(FieldReference expr, EmptyVisitationContext context) { diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8f95cdf07..c0b8d61f5 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -197,6 +197,8 @@ public Expression from(io.substrait.proto.Expression expr) { multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList())) .build(); } + case NESTED: + return from(expr.getNested()); case CAST: return ExpressionCreator.cast( protoTypeConverter.from(expr.getCast().getType()), @@ -361,6 +363,18 @@ private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.B } } + public Expression.Nested from(io.substrait.proto.Expression.Nested nested) { + switch (nested.getNestedTypeCase()) { + case LIST: + List list = + nested.getList().getValuesList().stream().map(this::from).collect(Collectors.toList()); + return ExpressionCreator.nestedList(nested.getNullable(), list); + default: + throw new UnsupportedOperationException( + "Unimplemented nested type: " + nested.getNestedTypeCase()); + } + } + public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { switch (literal.getLiteralTypeCase()) { case BOOLEAN: diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 57132a940..70d1436c9 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -348,6 +348,16 @@ public Optional visit( .build()); } + @Override + public Optional visit(Expression.NestedList expr, EmptyVisitationContext context) + throws E { + Optional> expressions = visitExprList(expr.values(), context); + + return expressions.map( + expressionList -> + Expression.NestedList.builder().from(expr).values(expressionList).build()); + } + protected Optional visitMultiOrListRecord( Expression.MultiOrListRecord multiOrListRecord, EmptyVisitationContext context) throws E { return visitExprList(multiOrListRecord.values(), context) diff --git a/core/src/test/java/io/substrait/type/proto/NestedListExpressionTest.java b/core/src/test/java/io/substrait/type/proto/NestedListExpressionTest.java new file mode 100644 index 000000000..e2c4b18b8 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/NestedListExpressionTest.java @@ -0,0 +1,94 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ImmutableExpression; +import org.junit.jupiter.api.Test; + +class NestedListExpressionTest extends TestBase { + io.substrait.expression.Expression literalExpression = + Expression.BoolLiteral.builder().value(true).build(); + Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42)); + + @Test + void rejectNestedListWithElementsOfDifferentTypes() { + ImmutableExpression.NestedList.Builder builder = + Expression.NestedList.builder().addValues(literalExpression).addValues(b.i32(12)); + assertThrows(AssertionError.class, builder::build); + } + + @Test + void acceptNestedListWithElementsOfSameType() { + ImmutableExpression.NestedList.Builder builder = + Expression.NestedList.builder().addValues(nonLiteralExpression).addValues(b.i32(12)); + assertDoesNotThrow(builder::build); + + io.substrait.relation.Project project = + io.substrait.relation.Project.builder() + .addExpressions(builder.build()) + .input(b.emptyScan()) + .build(); + verifyRoundTrip(project); + } + + @Test + void rejectEmptyNestedListTest() { + ImmutableExpression.NestedList.Builder builder = Expression.NestedList.builder(); + assertThrows(AssertionError.class, builder::build); + } + + @Test + void literalNestedListTest() { + Expression.NestedList literalNestedList = + Expression.NestedList.builder() + .addValues(literalExpression) + .addValues(literalExpression) + .build(); + + io.substrait.relation.Project project = + io.substrait.relation.Project.builder() + .addExpressions(literalNestedList) + .input(b.emptyScan()) + .build(); + + verifyRoundTrip(project); + } + + @Test + void literalNullableNestedListTest() { + Expression.NestedList literalNestedList = + Expression.NestedList.builder() + .addValues(literalExpression) + .addValues(literalExpression) + .nullable(true) + .build(); + + io.substrait.relation.Project project = + io.substrait.relation.Project.builder() + .addExpressions(literalNestedList) + .input(b.emptyScan()) + .build(); + + verifyRoundTrip(project); + } + + @Test + void nonLiteralNestedListTest() { + Expression.NestedList nonLiteralNestedList = + Expression.NestedList.builder() + .addValues(nonLiteralExpression) + .addValues(nonLiteralExpression) + .build(); + + io.substrait.relation.Project project = + io.substrait.relation.Project.builder() + .addExpressions(nonLiteralNestedList) + .input(b.emptyScan()) + .build(); + + verifyRoundTrip(project); + } +} diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 71de9a7d5..0d25c5141 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -1,5 +1,6 @@ package io.substrait.examples.util; +import io.substrait.expression.Expression; import io.substrait.expression.Expression.BinaryLiteral; import io.substrait.expression.Expression.BoolLiteral; import io.substrait.expression.Expression.Cast; @@ -256,6 +257,12 @@ public String visit(MultiOrList expr, EmptyVisitationContext context) throws Run return sb.toString(); } + @Override + public String visit(Expression.NestedList expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(FieldReference expr, EmptyVisitationContext context) throws RuntimeException { StringBuilder sb = new StringBuilder("FieldRef#"); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 181986289..69349eec3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -179,6 +179,11 @@ public Rel visit(org.apache.calcite.rel.core.Project project) { .map(this::toExpression) .collect(java.util.stream.Collectors.toList()); + // if there are no input fields, no remap is necessary + if (project.getInput().getRowType().getFieldCount() == 0) { + return Project.builder().expressions(expressions).input(apply(project.getInput())).build(); + } + // todo: eliminate excessive projects. This should be done by converting rexinputrefs to remaps. return Project.builder() .remap( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3406de7de..247d4b9ce 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -141,7 +141,8 @@ public static List defaults(TypeConverter typeConverter) { CallConverters.CASE, CallConverters.CAST.apply(typeConverter), CallConverters.REINTERPRET.apply(typeConverter), - new LiteralConstructorConverter(typeConverter)); + new SqlArrayValueConstructorCallConverter(typeConverter), + new SqlMapValueConstructorCallConverter()); } public interface SimpleCallConverter extends CallConverter { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2b8052889..6c0c5ee7a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -320,6 +320,24 @@ public RexNode visit(Expression.ListLiteral expr, Context context) throws Runtim return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args); } + @Override + public RexNode visit(Expression.NestedList expr, Context context) { + List args = + expr.values().stream().map(e -> e.accept(this, context)).collect(Collectors.toList()); + + // to preserve NestedList nullability + RelDataType elementType; + if (args.isEmpty()) { + throw new IllegalStateException("NestedList must have at least 1 element"); + } else { + elementType = args.get(0).getType(); + } + RelDataType nestedListType = typeFactory.createArrayType(elementType, -1); + nestedListType = typeFactory.createTypeWithNullability(nestedListType, expr.nullable()); + + return rexBuilder.makeCall(nestedListType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args); + } + @Override public RexNode visit(Expression.EmptyListLiteral expr, Context context) throws RuntimeException { RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlArrayValueConstructorCallConverter.java similarity index 55% rename from isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java rename to isthmus/src/main/java/io/substrait/isthmus/expression/SqlArrayValueConstructorCallConverter.java index a0f6b88d1..c8805a901 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlArrayValueConstructorCallConverter.java @@ -5,23 +5,21 @@ import io.substrait.isthmus.CallConverter; import io.substrait.isthmus.TypeConverter; import io.substrait.type.Type; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.function.Function; +import java.util.stream.Collectors; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlArrayValueConstructor; -import org.apache.calcite.sql.fun.SqlMapValueConstructor; -public class LiteralConstructorConverter implements CallConverter { +public class SqlArrayValueConstructorCallConverter implements CallConverter { private final TypeConverter typeConverter; - public LiteralConstructorConverter(TypeConverter typeConverter) { + public SqlArrayValueConstructorCallConverter(TypeConverter typeConverter) { this.typeConverter = typeConverter; } @@ -33,34 +31,28 @@ public Optional convert( return call.getOperands().isEmpty() ? toEmptyListLiteral(call) : toNonEmptyListLiteral(call, topLevelConverter); - } else if (operator instanceof SqlMapValueConstructor) { - return toMapLiteral(call, topLevelConverter); } return Optional.empty(); } - private Optional toMapLiteral( - RexCall call, Function topLevelConverter) { - List literals = - call.operands.stream() - .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) - .collect(java.util.stream.Collectors.toList()); - Map items = new HashMap<>(); - assert literals.size() % 2 == 0; - for (int i = 0; i < literals.size(); i += 2) { - items.put(literals.get(i), literals.get(i + 1)); - } - return Optional.of(ExpressionCreator.map(false, items)); - } - private Optional toNonEmptyListLiteral( RexCall call, Function topLevelConverter) { - return Optional.of( - ExpressionCreator.list( - call.getType().isNullable(), - call.operands.stream() - .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) - .collect(java.util.stream.Collectors.toList()))); + List expressions = + call.operands.stream().map(topLevelConverter).collect(Collectors.toList()); + + // Check if all operands are actually literals + if (expressions.stream().allMatch(e -> e instanceof Expression.Literal)) { + return Optional.of( + ExpressionCreator.list( + call.getType().isNullable(), + expressions.stream().map(e -> (Expression.Literal) e).collect(Collectors.toList()))); + } else { + return Optional.of( + Expression.NestedList.builder() + .nullable(call.getType().isNullable()) + .values(expressions) + .build()); + } } private Optional toEmptyListLiteral(RexCall call) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java new file mode 100644 index 000000000..8cf4958d8 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java @@ -0,0 +1,43 @@ +package io.substrait.isthmus.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.isthmus.CallConverter; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlMapValueConstructor; + +public class SqlMapValueConstructorCallConverter implements CallConverter { + + SqlMapValueConstructorCallConverter() {} + + @Override + public Optional convert( + RexCall call, Function topLevelConverter) { + SqlOperator operator = call.getOperator(); + if (operator instanceof SqlMapValueConstructor) { + return toMapLiteral(call, topLevelConverter); + } + return Optional.empty(); + } + + private Optional toMapLiteral( + RexCall call, Function topLevelConverter) { + List literals = + call.operands.stream() + .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) + .collect(java.util.stream.Collectors.toList()); + Map items = new HashMap<>(); + assert literals.size() % 2 == 0; + for (int i = 0; i < literals.size(); i += 2) { + items.put(literals.get(i), literals.get(i + 1)); + } + return Optional.of(ExpressionCreator.map(false, items)); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java new file mode 100644 index 000000000..e0c4b8023 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java @@ -0,0 +1,166 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.protobuf.ByteString; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.ImmutableExpression; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.apache.calcite.rel.RelNode; +import org.junit.jupiter.api.Test; + +class NestedExpressionsTest extends PlanTestBase { + + protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection = + DefaultExtensionCatalog.DEFAULT_COLLECTION; + protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection); + SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + + Expression literalExpression = Expression.BoolLiteral.builder().value(true).build(); + Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42)); + Expression.ScalarFunctionInvocation nonLiteralExpression2 = b.add(b.i32(3), b.i32(4)); + + final List tableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN, N.STRING); + final Rel commonTable = + b.namedScan(List.of("example"), List.of("a", "b", "c", "d", "e"), tableType); + final Rel emptyTable = b.emptyScan(); + + Expression fieldRef1 = b.fieldReference(commonTable, 2); + Expression fieldRef2 = b.fieldReference(commonTable, 4); + + @Test + void nestedListWithLiteralsTest() { + List expressionList = new ArrayList<>(); + Expression.NestedList literalNestedList = + Expression.NestedList.builder() + .addValues(literalExpression) + .addValues(literalExpression) + .build(); + expressionList.add(literalNestedList); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + + RelNode relNode = substraitToCalcite.convert(project); // substrait rel to calcite + Rel substraitRel = SubstraitRelVisitor.convert(relNode, extensions); // calcite to substrait + Expression project2 = ((Project) substraitRel).getExpressions().get(0); + assertEquals(ImmutableExpression.ListLiteral.class, project2.getClass()); + Expression.ListLiteral listLiteral = (Expression.ListLiteral) project2; + assertEquals(literalNestedList.values(), listLiteral.values()); + } + + @Test + void nestedListWithNonLiteralsTest() { + List expressionList = new ArrayList<>(); + + Expression.NestedList nonLiteralNestedList = + Expression.NestedList.builder() + .addValues(nonLiteralExpression) + .addValues(nonLiteralExpression2) + .build(); + expressionList.add(nonLiteralNestedList); + + Project project = + Project.builder() + .expressions(expressionList) + .input(commonTable) + // project only the nestedList expression and exclude the 5 input columns + .remap(Rel.Remap.of(Collections.singleton(5))) + .build(); + + assertFullRoundTrip(project); + } + + @Test + void nestedListWithFieldReferenceTest() { + Expression.NestedList nestedListWithField = + Expression.NestedList.builder().addValues(fieldRef1).addValues(fieldRef2).build(); + + List expressionList = new ArrayList<>(); + expressionList.add(nestedListWithField); + + Project project = + Project.builder() + .expressions(expressionList) + .input(commonTable) + .remap(Rel.Remap.of(Collections.singleton(5))) + .build(); + + assertFullRoundTrip(project); + } + + @Test + void nestedListWithStringLiteralsTest() { + Expression.NestedList nestedList = + Expression.NestedList.builder().addValues(b.str("xzy")).addValues(b.str("abc")).build(); + + Rel project = Project.builder().expressions(List.of(nestedList)).input(emptyTable).build(); + + RelNode relNode = substraitToCalcite.convert(project); // substrait rel to calcite + Rel substraitRel = SubstraitRelVisitor.convert(relNode, extensions); // calcite to substrait + Expression project2 = ((Project) substraitRel).getExpressions().get(0); + assertEquals(ImmutableExpression.ListLiteral.class, project2.getClass()); + Expression.ListLiteral listLiteral = (Expression.ListLiteral) project2; + assertEquals(nestedList.values(), listLiteral.values()); + } + + @Test + void nestedListWithBinaryLiteralTest() { + Expression binaryLiteral = + Expression.BinaryLiteral.builder() + .value(ByteString.copyFrom(new byte[] {0x01, 0x02})) + .build(); + + Expression.NestedList nestedList = + Expression.NestedList.builder().addValues(binaryLiteral).addValues(binaryLiteral).build(); + + Rel project = Project.builder().expressions(List.of(nestedList)).input(emptyTable).build(); + + RelNode relNode = substraitToCalcite.convert(project); // substrait rel to calcite + Rel substraitRel = SubstraitRelVisitor.convert(relNode, extensions); // calcite to substrait + Expression project2 = ((Project) substraitRel).getExpressions().get(0); + assertEquals(ImmutableExpression.ListLiteral.class, project2.getClass()); + Expression.ListLiteral listLiteral = (Expression.ListLiteral) project2; + assertEquals(nestedList.values(), listLiteral.values()); + } + + @Test + void nestedListWithSingleLiteralTest() { + List expressionList = new ArrayList<>(); + Expression.NestedList literalNestedList = + Expression.NestedList.builder().addValues(literalExpression).build(); + expressionList.add(literalNestedList); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + + RelNode relNode = substraitToCalcite.convert(project); // substrait rel to calcite + Rel substraitRel = SubstraitRelVisitor.convert(relNode, extensions); // calcite to substrait + Expression project2 = ((Project) substraitRel).getExpressions().get(0); + assertEquals(ImmutableExpression.ListLiteral.class, project2.getClass()); + Expression.ListLiteral listLiteral = (Expression.ListLiteral) project2; + assertEquals(literalNestedList.values(), listLiteral.values()); + } + + @Test + void nullableNestedListTest() { + List expressionList = new ArrayList<>(); + Expression.NestedList literalNestedList = + Expression.NestedList.builder() + .addValues(nonLiteralExpression) + .addValues(nonLiteralExpression2) + .nullable(true) + .build(); + expressionList.add(literalNestedList); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + + assertFullRoundTrip(project); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProjectTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProjectTest.java new file mode 100644 index 000000000..d6d7868b7 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ProjectTest.java @@ -0,0 +1,18 @@ +package io.substrait.isthmus; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import org.junit.jupiter.api.Test; + +class ProjectTest extends PlanTestBase { + final SubstraitBuilder b = new SubstraitBuilder(extensions); + final Rel emptyTable = b.emptyScan(); + + @Test + void avoidProjectRemapOnEmptyInput() { + Rel projection = + Project.builder().input(emptyTable).addExpressions(b.add(b.i32(1), b.i32(2))).build(); + assertFullRoundTrip(projection); + } +}