Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ public O visit(Expression.MultiOrList expr, C context) throws E {
return visitFallback(expr, context);
}

/**
 * Default handling for {@link Expression.NestedList}: no dedicated logic, the call is simply
 * forwarded to {@code visitFallback}.
 */
@Override
public O visit(Expression.NestedList expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(FieldReference expr, C context) throws E {
return visitFallback(expr, context);
Expand Down
41 changes: 41 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ default boolean nullable() {
}
}

/**
 * Common supertype for nested expressions (implemented by {@code NestedList}); unless a value is
 * supplied, {@link #nullable()} falls back to {@code false}.
 */
interface Nested extends Expression {
@Value.Default
default boolean nullable() {
return false;
}
}

<R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E;

Expand Down Expand Up @@ -922,6 +929,40 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
 * A non-empty list expression whose elements are arbitrary expressions sharing one type.
 *
 * <p>Note: an empty list cannot be represented by this class. For that case use {@link
 * ExpressionCreator#emptyList(boolean, Type)}, which returns an {@link EmptyListLiteral}.
 */
@Value.Immutable
abstract class NestedList implements Nested {
  /** The element expressions, in list order. */
  public abstract List<Expression> values();

  /** Returns a fresh builder for the generated immutable implementation. */
  public static ImmutableExpression.NestedList.Builder builder() {
    return ImmutableExpression.NestedList.builder();
  }

  /**
   * Build-time validation: the list must contain at least one element and all elements must
   * report the same type. Implemented with {@code assert}, so it only fires when assertions are
   * enabled.
   */
  @Value.Check
  protected void check() {
    assert values().size() > 0 : "To specify an empty list, use ExpressionCreator.emptyList()";

    assert values().stream().map(value -> value.getType()).distinct().count() < 2
        : "All values in NestedList must have the same type";
  }

  /** The list type: element type taken from the first value, nullability from this list. */
  @Override
  public Type getType() {
    Type elementType = values().get(0).getType();
    return Type.withNullability(nullable()).list(elementType);
  }

  @Override
  public <R, C extends VisitationContext, E extends Throwable> R accept(
      ExpressionVisitor<R, C, E> visitor, C context) throws E {
    return visitor.visit(this, context);
  }
}

@Value.Immutable
abstract class MultiOrListRecord {
public abstract List<Expression> values();
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,17 @@ public static Expression.StructLiteral struct(boolean nullable, Expression.Liter
return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build();
}

/**
* Creates a nested list expression with one or more elements.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Creator a nested list expression with one or more elements.
* Creates a nested list expression with one or more elements.

*
* <p>Note: This method cannot be used to construct an empty list. To create an empty list, use
* {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link
* Expression.EmptyListLiteral}.
*/
public static Expression.NestedList nestedList(boolean nullable, List<Expression> values) {
  // Set the list-level nullability first, then append every element and build.
  ImmutableExpression.NestedList.Builder listBuilder =
      Expression.NestedList.builder().nullable(nullable);
  return listBuilder.addAllValues(values).build();
}

public static Expression.StructLiteral struct(
boolean nullable, Iterable<? extends Expression.Literal> values) {
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr

R visit(Expression.MultiOrList expr, C context) throws E;

R visit(Expression.NestedList expr, C context) throws E;

R visit(FieldReference expr, C context) throws E;

R visit(Expression.SetPredicate expr, C context) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,22 @@ public Expression visit(
.build();
}

/**
 * Serializes a nested list into a proto {@code Expression} whose {@code nested} field carries
 * the converted element expressions and the list's nullability.
 */
@Override
public Expression visit(
    io.substrait.expression.Expression.NestedList expr, EmptyVisitationContext context)
    throws RuntimeException {
  // Convert the elements one by one, preserving their order.
  Expression.Nested.List.Builder protoList = Expression.Nested.List.newBuilder();
  for (io.substrait.expression.Expression element : expr.values()) {
    protoList.addValues(toProto(element));
  }

  Expression.Nested.Builder nested =
      Expression.Nested.newBuilder().setList(protoList).setNullable(expr.nullable());
  return Expression.newBuilder().setNested(nested).build();
}

@Override
public Expression visit(FieldReference expr, EmptyVisitationContext context) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ public Expression from(io.substrait.proto.Expression expr) {
multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList()))
.build();
}
case NESTED:
return from(expr.getNested());
case CAST:
return ExpressionCreator.cast(
protoTypeConverter.from(expr.getCast().getType()),
Expand Down Expand Up @@ -361,6 +363,17 @@ private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.B
}
}

public Expression.Nested from(io.substrait.proto.Expression.Nested nested) {
switch (nested.getNestedTypeCase()) {
case LIST:
List<Expression> list =
nested.getList().getValuesList().stream().map(this::from).collect(Collectors.toList());
return ExpressionCreator.nestedList(nested.getNullable(), list);
default:
throw new IllegalStateException("Unimplemented nested type: " + nested.getNestedTypeCase());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: We can use UnsupportedOperationException here, which better matches your error message 🙂

}
}

public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
switch (literal.getLiteralTypeCase()) {
case BOOLEAN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ public Optional<Expression> visit(
.build());
}

/**
 * Visits the elements of a nested list and, when {@code visitExprList} yields a replacement
 * list, returns a copy of {@code expr} with those values; otherwise returns an empty Optional.
 */
@Override
public Optional<Expression> visit(Expression.NestedList expr, EmptyVisitationContext context)
    throws E {
  return visitExprList(expr.values(), context)
      .map(
          updatedValues ->
              Expression.NestedList.builder().from(expr).values(updatedValues).build());
}

protected Optional<Expression.MultiOrListRecord> visitMultiOrListRecord(
Expression.MultiOrListRecord multiOrListRecord, EmptyVisitationContext context) throws E {
return visitExprList(multiOrListRecord.values(), context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package io.substrait.type.proto;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.expression.ImmutableExpression;
import org.junit.jupiter.api.Test;

class NestedListExpressionTest extends TestBase {
// Shared fixtures: a boolean literal and a non-literal (scalar function call) expression.
Expression literalExpression = Expression.BoolLiteral.builder().value(true).build();
Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42));

@Test
void DifferentTypedLiteralsNestedListTest() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void DifferentTypedLiteralsNestedListTest() {
void rejectNestedListWithElementsOfDifferentTypes() {

minor suggestion: if we include the expected behaviour in the test name, it's easier to see at a glance what it's testing. Similar suggestions for the tests below.

ImmutableExpression.NestedList.Builder builder =
Expression.NestedList.builder().addValues(literalExpression).addValues(b.i32(12));
assertThrows(AssertionError.class, builder::build);
}

@Test
void SameTypedLiteralsNestedListTest() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void SameTypedLiteralsNestedListTest() {
void acceptNestedListWithElementsOfSameType() {

Beyond the name suggestion for this test, it is a bit redundant because it doesn't really verify anything differently than your normal tests below do.

ImmutableExpression.NestedList.Builder builder =
Expression.NestedList.builder().addValues(nonLiteralExpression).addValues(b.i32(12));
assertDoesNotThrow(builder::build);

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.addExpressions(builder.build())
.input(b.emptyScan())
.build();
verifyRoundTrip(project);
}

@Test
void EmptyNestedListTest() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void EmptyNestedListTest() {
void rejectEmptyNestedListTest() {

ImmutableExpression.NestedList.Builder builder = Expression.NestedList.builder();
assertThrows(AssertionError.class, builder::build);
}

/** A nested list built only from literals survives a round trip inside a Project. */
@Test
void literalNestedListTest() {
  Expression.NestedList boolList =
      Expression.NestedList.builder()
          .addValues(literalExpression)
          .addValues(literalExpression)
          .build();

  verifyRoundTrip(
      io.substrait.relation.Project.builder()
          .input(b.emptyScan())
          .addExpressions(boolList)
          .build());
}

/** Same as the literal case, but with the list itself marked nullable. */
@Test
void literalNullableNestedListTest() {
  Expression.NestedList nullableBoolList =
      Expression.NestedList.builder()
          .nullable(true)
          .addValues(literalExpression)
          .addValues(literalExpression)
          .build();

  verifyRoundTrip(
      io.substrait.relation.Project.builder()
          .input(b.emptyScan())
          .addExpressions(nullableBoolList)
          .build());
}

/** A nested list of non-literal (function call) elements survives a round trip in a Project. */
@Test
void nonLiteralNestedListTest() {
  Expression.NestedList functionCallList =
      Expression.NestedList.builder()
          .addValues(nonLiteralExpression)
          .addValues(nonLiteralExpression)
          .build();

  verifyRoundTrip(
      io.substrait.relation.Project.builder()
          .input(b.emptyScan())
          .addExpressions(functionCallList)
          .build());
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.examples.util;

import io.substrait.expression.Expression;
import io.substrait.expression.Expression.BinaryLiteral;
import io.substrait.expression.Expression.BoolLiteral;
import io.substrait.expression.Expression.Cast;
Expand Down Expand Up @@ -256,6 +257,12 @@ public String visit(MultiOrList expr, EmptyVisitationContext context) throws Run
return sb.toString();
}

/**
 * Renders a nested list as the fixed placeholder {@code <NestedList>}; the individual elements
 * are not printed.
 */
@Override
public String visit(Expression.NestedList expr, EmptyVisitationContext context)
throws RuntimeException {
return "<NestedList>";
}

@Override
public String visit(FieldReference expr, EmptyVisitationContext context) throws RuntimeException {
StringBuilder sb = new StringBuilder("FieldRef#");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.substrait.isthmus;

import static java.util.Objects.requireNonNull;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlMultisetValueConstructor;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;

/**
* Substrait-specific constructor to map back to the Expression NestedList type in Substrait. This
* constructor creates a special type of SqlKind.ARRAY_VALUE_CONSTRUCTOR for lists whose elements
* may be arbitrary expressions, literal or non-literal.
Comment on lines +14 to +15

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conceptually and by looking at the tests (NestedListWithJustLiteralsTest), it seems this can indeed handle literal expressions, can you double check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can handle both literals and non-literals. I'll update the comment to avoid the confusion

*/
public class NestedListConstructor extends SqlMultisetValueConstructor {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you picked SqlMultisetValueConstructor instead of the more natural SqlArrayValueConstructor because CallConverters.java#L144 would then match on your NestedList and treat it as the regular SqlArrayValueConstructor (then invoke LiteralConstructorConverter.java#L32 which is not what we want here).

SqlMultisetValueConstructor is conceptually wrong as a multiset is radically different from an array/list in standard SQL, on top of that I imagine that tomorrow you might want to support something similar and hit the same issue you are trying to avoid here by using this class in the first place.

The impedance mismatch is that in SQL (and Calcite), arrays and lists are technically the same entity, while IIRC in Substrait they are treated as different entities (@benbellick can you confirm this?).

By looking at LiteralConstructorConverter, there is an implicit assumption that arrays store only literals, we go down that route without checking if elements in the array are really literals (LiteralConstructorConverter.java#L62).

It's probably enough to change LiteralConstructorConverter::toNonEmptyListLiteral to something like this (haven't tested it):

private Optional<Expression> toNonEmptyListLiteral(
      RexCall call, Function<RexNode, Expression> topLevelConverter) {
    List<Expression> expressions = call.operands.stream()
        .map(topLevelConverter)
        .collect(Collectors.toList());

    // Check if all operands are actually literals
    if (expressions.stream().allMatch(e -> e instanceof Expression.Literal)) {
      return Optional.of(ExpressionCreator.list(
          call.getType().isNullable(),
          expressions.stream()
              .map(e -> (Expression.Literal) e)
              .collect(Collectors.toList())));
    }

    return Optional.empty();
  }

I suggest to extend SqlArrayValueConstructor (which, I know, extends SqlMultisetValueConstructor, but still), then fix LiteralConstructorConverter as suggested, so that we can continue with NestedExpressionConverter which comes just after, and we should be good

Copy link
Contributor Author

@gord02 gord02 Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure I am understanding it correctly, let me know your thoughts on the following scenarios: We want to ensure that the roundtrip of both a nested list with literals and non-literals are both returned to a nested list. If the literalConstructorConverter is run first on a list of just literals, then it would pass and then wouldn't be mapped back to a nested list. In the other case, where the nestedExpressionConverter is run first, the literal lists that were originally not a NestedList would become a nested list. Does the above account for this or is the difference not important?

Also, is there a way to meaningfully extend the SqlArrayValueConstructor class? Its definition is bare with just a constructor to its parent type.


/**
 * Creates the operator, passing the name {@code "NESTEDLIST"} and {@code
 * SqlKind.ARRAY_VALUE_CONSTRUCTOR} to the superclass constructor.
 */
public NestedListConstructor() {
super("NESTEDLIST", SqlKind.ARRAY_VALUE_CONSTRUCTOR);
}

/**
 * Infers the call's return type as a non-nullable array of the operands' component type.
 *
 * @param opBinding binding that supplies the type factory and the operand types
 * @return array type whose element type is the inferred component type
 * @throws NullPointerException if no component type can be inferred from the operands
 */
@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
  RelDataType componentType =
      requireNonNull(
          getComponentType(opBinding.getTypeFactory(), opBinding.collectOperandTypes()),
          "inferred array element type");

  // Elements whose type differs from the component type get an explicit cast.
  SqlValidatorUtil.adjustTypeForArrayConstructor(componentType, opBinding);

  return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), componentType, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ public Rel visit(org.apache.calcite.rel.core.Project project) {
.map(this::toExpression)
.collect(java.util.stream.Collectors.toList());

// if there are no input fields, no remap is necessary

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// if there is no input fields, don’t put a remapping on it
// if there are no input fields, don’t put a remapping on it

nit

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor/pedantic: small rewording + avoiding extra whitespace

Suggested change
// if there is no input fields, don’t put a remapping on it
// if there are no input fields, no remap is necessary

if (project.getInput().getRowType().getFieldCount() == 0) {
return Project.builder().expressions(expressions).input(apply(project.getInput())).build();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change looks good, but is a bit out of place in this PR. From experimenting locally it seems that one of your tests fails locally with it because a pointless remap is introduced.

I do think it's a good change, and we should keep it in this PR, but could you add an explicit test for this feature along the lines of avoidProjectRemapOnEmptyInput() to ProjectRelRoundtripTest.

// todo: eliminate excessive projects. This should be done by converting rexinputrefs to remaps.
return Project.builder()
.remap(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.isthmus.calcite;

import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.NestedListConstructor;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -36,6 +37,8 @@ public class SubstraitOperatorTable implements SqlOperatorTable {
AggregateFunctions.SUM,
AggregateFunctions.SUM0));

/**
 * Shared {@link NestedListConstructor} operator instance. Declared {@code final} so the
 * singleton cannot be reassigned by other code.
 */
public static final NestedListConstructor NESTED_LIST_CONSTRUCTOR = new NestedListConstructor();

// SQL Kinds for which Substrait specific operators are provided
private static final Set<SqlKind> OVERRIDE_KINDS =
EnumSet.copyOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
CallConverters.CASE,
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new LiteralConstructorConverter(typeConverter));
new LiteralConstructorConverter(typeConverter),
new NestedExpressionConverter());
}

public interface SimpleCallConverter extends CallConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.substrait.isthmus.SubstraitRelNodeConverter;
import io.substrait.isthmus.SubstraitRelNodeConverter.Context;
import io.substrait.isthmus.TypeConverter;
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
import io.substrait.type.StringTypeVisitor;
import io.substrait.type.Type;
import io.substrait.util.DecimalUtil;
Expand Down Expand Up @@ -320,6 +321,13 @@ public RexNode visit(Expression.ListLiteral expr, Context context) throws Runtim
return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args);
}

/**
 * Converts a Substrait nested list into a Calcite call on {@link
 * SubstraitOperatorTable#NESTED_LIST_CONSTRUCTOR}, with each element converted by this visitor.
 */
@Override
public RexNode visit(Expression.NestedList expr, Context context) {
  // NOTE(review): expr.nullable() is not passed to the call here — confirm that list
  // nullability survives an isthmus round trip.
  List<RexNode> operands =
      expr.values().stream()
          .map(element -> element.accept(this, context))
          .collect(Collectors.toList());
  return rexBuilder.makeCall(SubstraitOperatorTable.NESTED_LIST_CONSTRUCTOR, operands);
}

@Override
public RexNode visit(Expression.EmptyListLiteral expr, Context context) throws RuntimeException {
RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType());
Expand Down
Loading