Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ public O visit(Expression.MultiOrList expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.NestedList expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(FieldReference expr, C context) throws E {
return visitFallback(expr, context);
Expand Down
27 changes: 27 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ default boolean nullable() {
}
}

interface Nested extends Expression {
@Value.Default
default boolean nullable() {
return false;
}
}

<R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E;

Expand Down Expand Up @@ -922,6 +929,26 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

@Value.Immutable
abstract class NestedList implements Nested {
public abstract List<Expression> values();

@Override
public Type getType() {
return Type.withNullability(nullable()).list(values().get(0).getType());
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}

public static ImmutableExpression.NestedList.Builder builder() {
return ImmutableExpression.NestedList.builder();
}
}

@Value.Immutable
abstract class MultiOrListRecord {
public abstract List<Expression> values();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ public static Expression.StructLiteral struct(boolean nullable, Expression.Liter
return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build();
}

public static Expression.NestedList nestedList(boolean nullable, List<Expression> values) {
return Expression.NestedList.builder().nullable(nullable).addAllValues(values).build();
}

public static Expression.StructLiteral struct(
boolean nullable, Iterable<? extends Expression.Literal> values) {
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr

R visit(Expression.MultiOrList expr, C context) throws E;

R visit(Expression.NestedList expr, C context) throws E;

R visit(FieldReference expr, C context) throws E;

R visit(Expression.SetPredicate expr, C context) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,26 @@ public Expression visit(
.build();
}

private Expression nested(Consumer<Expression.Nested.Builder> consumer) {
Expression.Nested.Builder builder = Expression.Nested.newBuilder();
builder.setNullable(builder.getNullable());
consumer.accept(builder);
return Expression.newBuilder().setNested(builder).build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.NestedList expr, EmptyVisitationContext context)
throws RuntimeException {
return nested(
bldr -> {
List<Expression> values =
expr.values().stream().map(this::toProto).collect(Collectors.toList());
bldr.setNullable(expr.nullable())
.setList(Expression.Nested.List.newBuilder().addAllValues(values));
});
}

@Override
public Expression visit(FieldReference expr, EmptyVisitationContext context) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ public Expression from(io.substrait.proto.Expression expr) {
multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList()))
.build();
}
case NESTED:
return from(expr.getNested());
case CAST:
return ExpressionCreator.cast(
protoTypeConverter.from(expr.getCast().getType()),
Expand Down Expand Up @@ -361,6 +363,15 @@ private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.B
}
}

public Expression.Nested from(io.substrait.proto.Expression.Nested nested) {
if (nested.getNestedTypeCase() == io.substrait.proto.Expression.Nested.NestedTypeCase.LIST) {
List<Expression> list =
nested.getList().getValuesList().stream().map(this::from).collect(Collectors.toList());
return ExpressionCreator.nestedList(nested.getNullable(), list);
}
return null;
}

public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
switch (literal.getLiteralTypeCase()) {
case BOOLEAN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ public Optional<Expression> visit(
.build());
}

@Override
public Optional<Expression> visit(Expression.NestedList expr, EmptyVisitationContext context)
throws E {
Optional<List<Expression>> expressions = visitExprList(expr.values(), context);

return expressions.map(
expressionList ->
Expression.NestedList.builder().from(expr).values(expressionList).build());
}

protected Optional<Expression.MultiOrListRecord> visitMultiOrListRecord(
Expression.MultiOrListRecord multiOrListRecord, EmptyVisitationContext context) throws E {
return visitExprList(multiOrListRecord.values(), context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.substrait.type.proto;

import io.substrait.TestBase;
import io.substrait.expression.Expression;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;

class NestedListExpressionTest extends TestBase {
io.substrait.expression.Expression literalExpression =
Expression.BoolLiteral.builder().value(true).build();
Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42));

@Test
void literalNestedListTest() {
List<Expression> expressionList = new ArrayList<>();
Expression.NestedList literalNestedList =
Expression.NestedList.builder()
.addValues(literalExpression)
.addValues(literalExpression)
.build();
expressionList.add(literalNestedList);

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.expressions(expressionList)
.input(b.emptyScan())
.build();

verifyRoundTrip(project);
}

@Test
void nonLiteralNestedListTest() {
List<Expression> expressionList = new ArrayList<>();

Expression.NestedList nonLiteralNestedList =
Expression.NestedList.builder()
.addValues(nonLiteralExpression)
.addValues(nonLiteralExpression)
.build();
expressionList.add(nonLiteralNestedList);

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.expressions(expressionList)
.input(b.emptyScan())
.build();

verifyRoundTrip(project);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.examples.util;

import io.substrait.expression.Expression;
import io.substrait.expression.Expression.BinaryLiteral;
import io.substrait.expression.Expression.BoolLiteral;
import io.substrait.expression.Expression.Cast;
Expand Down Expand Up @@ -256,6 +257,12 @@ public String visit(MultiOrList expr, EmptyVisitationContext context) throws Run
return sb.toString();
}

@Override
public String visit(Expression.NestedList expr, EmptyVisitationContext context)
throws RuntimeException {
return "<NestedList>";
}

@Override
public String visit(FieldReference expr, EmptyVisitationContext context) throws RuntimeException {
StringBuilder sb = new StringBuilder("FieldRef#");
Expand Down
19 changes: 19 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/NestedFunctions.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.substrait.isthmus;

import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.ReturnTypes;

/** Substrait-specific extension function for the Expression Nested List type */
public class NestedFunctions {

public static final SqlFunction NESTED_LIST =
new SqlFunction(
"nested_list",
SqlKind.OTHER_FUNCTION,
ReturnTypes.BOOLEAN,
null,
null,
SqlFunctionCategory.USER_DEFINED_FUNCTION);
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ public Rel visit(org.apache.calcite.rel.core.Project project) {
.map(this::toExpression)
.collect(java.util.stream.Collectors.toList());

// if there is no input fields, don’t put a remapping on it

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// if there is no input fields, don’t put a remapping on it
// if there are no input fields, don’t put a remapping on it

nit

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor/pedantic: small rewording + avoiding extra whitespace

Suggested change
// if there is no input fields, don’t put a remapping on it
// if there are no input fields, no remap is necessary

if (project.getInput().getRowType().getFieldCount() == 0) {
return Project.builder().expressions(expressions).input(apply(project.getInput())).build();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change looks good, but is a bit out of place in this PR. From experimenting locally it seems that one of your tests fails locally with it because a pointless remap is introduced.

I do think it's a good change, and we should keep it in this PR, but could you add an explicit test for this features along the lines of avoidProjectRemapOnEmptyInput() to ProjectRelRoundtripTest.

// todo: eliminate excessive projects. This should be done by converting rexinputrefs to remaps.
return Project.builder()
.remap(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.isthmus.calcite;

import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.NestedFunctions;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -34,7 +35,8 @@ public class SubstraitOperatorTable implements SqlOperatorTable {
AggregateFunctions.MIN,
AggregateFunctions.AVG,
AggregateFunctions.SUM,
AggregateFunctions.SUM0));
AggregateFunctions.SUM0,
NestedFunctions.NESTED_LIST));

// SQL Kinds for which Substrait specific operators are provided
private static final Set<SqlKind> OVERRIDE_KINDS =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
CallConverters.CASE,
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new LiteralConstructorConverter(typeConverter));
new LiteralConstructorConverter(typeConverter),
new NestedExpressionConverter());
}

public interface SimpleCallConverter extends CallConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.NestedFunctions;
import io.substrait.isthmus.SubstraitRelNodeConverter;
import io.substrait.isthmus.SubstraitRelNodeConverter.Context;
import io.substrait.isthmus.TypeConverter;
Expand Down Expand Up @@ -320,6 +321,13 @@ public RexNode visit(Expression.ListLiteral expr, Context context) throws Runtim
return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args);
}

@Override
public RexNode visit(Expression.NestedList expr, Context context) {
List<RexNode> args =
expr.values().stream().map(e -> e.accept(this, context)).collect(Collectors.toList());
return rexBuilder.makeCall(NestedFunctions.NESTED_LIST, args);
}

@Override
public RexNode visit(Expression.EmptyListLiteral expr, Context context) throws RuntimeException {
RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.Expression;
import io.substrait.isthmus.CallConverter;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;

public class NestedExpressionConverter implements CallConverter {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its a reasonable choice not to have this be a nested class in CallConverters. That said, could you rename this to NestedExpressionCallConverter to make differentiate from the other things we call converters.


public NestedExpressionConverter() {}

@Override
public Optional<Expression> convert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {

if (!call.getOperator().getName().equals("nested_list")) {
return Optional.empty();
}

List<Expression> values =
call.operands.stream().map(topLevelConverter).collect(Collectors.toList());

return Optional.of(
Expression.NestedList.builder()
.nullable(call.getType().isNullable())
.values(values)
.build());
}
}
Loading