diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 66a0003cf..f70d5a3fe 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E { return visitFallback(expr, context); } + @Override + public O visit(Expression.UserDefinedAnyLiteral expr, C context) throws E { + return visitFallback(expr, context); + } + + @Override + public O visit(Expression.UserDefinedStructLiteral expr, C context) throws E { + return visitFallback(expr, context); + } + @Override public O visit(Expression.NestedStruct expr, C context) throws E { return visitFallback(expr, context); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 732ee8007..483a627bf 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -693,21 +693,94 @@ public R accept( } } + /** + * Base interface for user-defined literals. + * + *

User-defined literals can be encoded in one of two ways as per the Substrait spec: + * + *

+ */ + interface UserDefinedLiteral extends Literal { + String urn(); + + String name(); + + List typeParameters(); + } + + /** + * User-defined literal with value encoded as {@link com.google.protobuf.Any}. + * + *

This encoding allows for arbitrary binary data to be stored in the literal value. + */ @Value.Immutable - abstract class UserDefinedLiteral implements Literal { - public abstract ByteString value(); + abstract class UserDefinedAnyLiteral implements UserDefinedLiteral { + @Override + public abstract String urn(); + + @Override + public abstract String name(); + + @Override + public abstract List typeParameters(); + + public abstract com.google.protobuf.Any value(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); + } + + public static ImmutableExpression.UserDefinedAnyLiteral.Builder builder() { + return ImmutableExpression.UserDefinedAnyLiteral.builder(); + } + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + + /** + * User-defined literal with value encoded as {@link + * io.substrait.proto.Expression.Literal.Struct}. + * + *

This encoding uses a structured list of fields to represent the literal value. + */ + @Value.Immutable + abstract class UserDefinedStructLiteral implements UserDefinedLiteral { + @Override public abstract String urn(); + @Override public abstract String name(); @Override - public Type getType() { - return Type.withNullability(nullable()).userDefined(urn(), name()); + public abstract List typeParameters(); + + public abstract List fields(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); } - public static ImmutableExpression.UserDefinedLiteral.Builder builder() { - return ImmutableExpression.UserDefinedLiteral.builder(); + public static ImmutableExpression.UserDefinedStructLiteral.Builder builder() { + return ImmutableExpression.UserDefinedStructLiteral.builder(); } @Override diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index e83cd9956..2175f7e0a 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -295,13 +295,51 @@ public static Expression.NestedStruct nestedStruct(boolean nullable, Expression. return Expression.NestedStruct.builder().nullable(nullable).addFields(fields).build(); } - public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String urn, String name, Any value) { - return Expression.UserDefinedLiteral.builder() + /** + * Create a UserDefinedAnyLiteral with google.protobuf.Any representation. + * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param value the value, encoded as google.protobuf.Any + */ + public static Expression.UserDefinedAnyLiteral userDefinedLiteralAny( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + Any value) { + return Expression.UserDefinedAnyLiteral.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(typeParameters) + .value(value) + .build(); + } + + /** + * Create a UserDefinedStructLiteral with Struct representation. + * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param fields the fields, as a list of Literal values + */ + public static Expression.UserDefinedStructLiteral userDefinedLiteralStruct( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + java.util.List fields) { + return Expression.UserDefinedStructLiteral.builder() .nullable(nullable) .urn(urn) .name(name) - .value(value.toByteString()) + .addAllTypeParameters(typeParameters) + .addAllFields(fields) .build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 764f995c2..2de9b17a4 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -64,7 +64,9 @@ public interface ExpressionVisitor { - try { - bldr.setNullable(expr.nullable()) - .setUserDefined( - Expression.Literal.UserDefined.newBuilder() - .setTypeReference(typeReference) - .setValue(Any.parseFrom(expr.value()))) - .build(); - } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException(e); - } + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters( + expr.typeParameters().stream() + .map(typeProtoConverter::toProto) + .collect(java.util.stream.Collectors.toList())) + .setValue(expr.value()); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); + }); + } + + @Override + public Expression visit( + io.substrait.expression.Expression.UserDefinedStructLiteral expr, + EmptyVisitationContext context) { + int typeReference = + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); + return lit( + bldr -> { + Expression.Literal.Struct structLiteral = + Expression.Literal.Struct.newBuilder() + .addAllFields( + expr.fields().stream() + .map(this::toLiteral) + .collect(java.util.stream.Collectors.toList())) + .build(); + + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters( + expr.typeParameters().stream() + .map(typeProtoConverter::toProto) + .collect(java.util.stream.Collectors.toList())) + .setStruct(structLiteral); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); }); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 7f3eb6601..2559338df 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -492,10 +492,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = literal.getUserDefined(); + SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); - return ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); + String urn = type.urn(); + String name = type.name(); + List typeParameters = + userDefinedLiteral.getTypeParametersList().stream() + .map(protoTypeConverter::from) + .collect(Collectors.toList()); + + switch (userDefinedLiteral.getValCase()) { + case VALUE: + return ExpressionCreator.userDefinedLiteralAny( + literal.getNullable(), urn, name, typeParameters, userDefinedLiteral.getValue()); + case STRUCT: + return ExpressionCreator.userDefinedLiteralStruct( + literal.getNullable(), + urn, + name, + typeParameters, + userDefinedLiteral.getStruct().getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())); + case VAL_NOT_SET: + throw new IllegalStateException( + "UserDefined literal has no value (neither 'value' nor 'struct' is set)"); + default: + throw new IllegalStateException( + "Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase()); + } } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 89aad954e..31214878c 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog { "extension:io.substrait:functions_rounding_decimal"; public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; + public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types"; public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = loadDefaultCollection(); @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { .map(c -> String.format("/functions_%s.yaml", c)) .collect(Collectors.toList()); + defaultFiles.add("/extension_types.yaml"); + return SimpleExtension.load(defaultFiles); } } diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index ef0b9879a..ec2df2646 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -214,7 +214,13 @@ public Optional visit(Expression.NestedStruct expr, EmptyVisitationC @Override public Optional visit( - Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E { + Expression.UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws E { + return visitLiteral(expr); + } + + @Override + public Optional visit( + Expression.UserDefinedStructLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..c9cb3f780 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -393,6 +393,19 @@ abstract class UserDefined implements Type { public abstract String name(); + /** + * Returns the type parameters for this user-defined type. + * + *

Type parameters are used to represent parameterized/generic types, such as {@code + * vector}. + * + * @return a list of type parameters, or an empty list if this type is not parameterized + */ + @Value.Default + public java.util.List typeParameters() { + return java.util.Collections.emptyList(); + } + public static ImmutableType.UserDefined.Builder builder() { return ImmutableType.UserDefined.builder(); } @@ -402,4 +415,50 @@ public R accept(TypeVisitor typeVisitor) throws E return typeVisitor.visit(this); } } + + /** + * Represents a type parameter for user-defined types. + * + *

Type parameters can be data types (like {@code i32} in {@code List}), or value + * parameters (like the {@code 10} in {@code VARCHAR<10>}). This interface provides a type-safe + * representation of all possible parameter kinds. + */ + interface Parameter {} + + /** A data type parameter, such as the {@code i32} in {@code List}. */ + @Value.Immutable + abstract class ParameterDataType implements Parameter { + public abstract Type type(); + } + + /** A boolean value parameter. */ + @Value.Immutable + abstract class ParameterBooleanValue implements Parameter { + public abstract boolean value(); + } + + /** An integer value parameter, such as the {@code 10} in {@code VARCHAR<10>}. */ + @Value.Immutable + abstract class ParameterIntegerValue implements Parameter { + public abstract long value(); + } + + /** An enum value parameter (represented as a string). */ + @Value.Immutable + abstract class ParameterEnumValue implements Parameter { + public abstract String value(); + } + + /** A string value parameter. */ + @Value.Immutable + abstract class ParameterStringValue implements Parameter { + public abstract String value(); + } + + /** An explicitly null/unspecified parameter, used to select the default value (if any). */ + class ParameterNull implements Parameter { + public static final ParameterNull INSTANCE = new ParameterNull(); + + private ParameterNull() {} + } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 691d4bce5..67d7bc9b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -165,6 +165,6 @@ public final T visit(final Type.Map expr) { public final T visit(final Type.UserDefined expr) { int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); - return typeContainer(expr).userDefined(ref); + return typeContainer(expr).userDefined(ref, expr.typeParameters()); } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 6a1bc3186..57b1f26b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -131,6 +131,9 @@ public final T struct(T... types) { public abstract T userDefined(int ref); + public abstract T userDefined( + int ref, java.util.List typeParameters); + protected abstract T wrap(Object o); protected abstract I i(int integerValue); diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 4e0caa7c2..817f7b0b3 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -262,6 +262,13 @@ public ParameterizedType userDefined(int ref) { "User defined types are not supported in Parameterized Types for now"); } + @Override + public ParameterizedType userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Parameterized Types for now"); + } + @Override protected ParameterizedType wrap(final Object o) { ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 95d42328a..bdb600c1c 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -2,6 +2,7 @@ import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; +import io.substrait.type.ImmutableType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; @@ -90,7 +91,16 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); + boolean nullable = isNullable(userDefined.getNullability()); + return io.substrait.type.Type.UserDefined.builder() + .nullable(nullable) + .urn(t.urn()) + .name(t.name()) + .typeParameters( + userDefined.getTypeParametersList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList())) + .build(); } case USER_DEFINED_TYPE_REFERENCE: throw new UnsupportedOperationException("Unsupported user defined reference: " + type); @@ -118,4 +128,28 @@ private static TypeCreator n(io.substrait.proto.Type.Nullability n) { ? TypeCreator.NULLABLE : TypeCreator.REQUIRED; } + + public io.substrait.type.Type.Parameter from(io.substrait.proto.Type.Parameter parameter) { + switch (parameter.getParameterCase()) { + case NULL: + return io.substrait.type.Type.ParameterNull.INSTANCE; + case DATA_TYPE: + return ImmutableType.ParameterDataType.builder() + .type(from(parameter.getDataType())) + .build(); + case BOOLEAN: + return ImmutableType.ParameterBooleanValue.builder().value(parameter.getBoolean()).build(); + case INTEGER: + return ImmutableType.ParameterIntegerValue.builder().value(parameter.getInteger()).build(); + case ENUM: + return ImmutableType.ParameterEnumValue.builder().value(parameter.getEnum()).build(); + case STRING: + return ImmutableType.ParameterStringValue.builder().value(parameter.getString()).build(); + case PARAMETER_NOT_SET: + throw new IllegalArgumentException("Parameter type is not set: " + parameter); + default: + throw new UnsupportedOperationException( + "Unsupported parameter type: " + parameter.getParameterCase()); + } + } } diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 96cddd395..f9d2129d2 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -355,6 +355,13 @@ public DerivationExpression userDefined(int ref) { "User defined types are not supported in Derivation Expressions for now"); } + @Override + public DerivationExpression userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Derivation Expressions for now"); + } + @Override protected DerivationExpression wrap(final Object o) { DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 2d0ed0ffc..6422904c4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -5,25 +5,68 @@ /** Convert from {@link io.substrait.type.Type} to {@link io.substrait.proto.Type} */ public class TypeProtoConverter extends BaseProtoConverter { - private static final BaseProtoTypes NULLABLE = - new Types(Type.Nullability.NULLABILITY_NULLABLE); - private static final BaseProtoTypes REQUIRED = - new Types(Type.Nullability.NULLABILITY_REQUIRED); + // Instance fields (not static) because Types is a non-static inner class that calls + // TypeProtoConverter.this.toProto() to recursively convert nested type parameters. + // Each converter instance needs its own Types instances to ensure type registrations + // use the correct ExtensionCollector. + private final BaseProtoTypes NULLABLE; + private final BaseProtoTypes REQUIRED; public TypeProtoConverter(ExtensionCollector extensionCollector) { super(extensionCollector, "Type literals cannot contain parameters or expressions."); + NULLABLE = new Types(Type.Nullability.NULLABILITY_NULLABLE); + REQUIRED = new Types(Type.Nullability.NULLABILITY_REQUIRED); } public io.substrait.proto.Type toProto(io.substrait.type.Type type) { return type.accept(this); } + public io.substrait.proto.Type.Parameter toProto(io.substrait.type.Type.Parameter parameter) { + if (parameter instanceof io.substrait.type.Type.ParameterNull) { + return Type.Parameter.newBuilder() + .setNull(com.google.protobuf.Empty.getDefaultInstance()) + .build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterDataType) { + io.substrait.type.Type.ParameterDataType dataType = + (io.substrait.type.Type.ParameterDataType) parameter; + return Type.Parameter.newBuilder().setDataType(toProto(dataType.type())).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterBooleanValue) { + io.substrait.type.Type.ParameterBooleanValue boolValue = + (io.substrait.type.Type.ParameterBooleanValue) parameter; + return Type.Parameter.newBuilder().setBoolean(boolValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterIntegerValue) { + io.substrait.type.Type.ParameterIntegerValue intValue = + (io.substrait.type.Type.ParameterIntegerValue) parameter; + return Type.Parameter.newBuilder().setInteger(intValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterEnumValue) { + io.substrait.type.Type.ParameterEnumValue enumValue = + (io.substrait.type.Type.ParameterEnumValue) parameter; + return Type.Parameter.newBuilder().setEnum(enumValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterStringValue) { + io.substrait.type.Type.ParameterStringValue stringValue = + (io.substrait.type.Type.ParameterStringValue) parameter; + return Type.Parameter.newBuilder().setString(stringValue.value()).build(); + } else { + throw new UnsupportedOperationException( + "Unsupported parameter type: " + parameter.getClass()); + } + } + @Override public BaseProtoTypes typeContainer(final boolean nullable) { return nullable ? NULLABLE : REQUIRED; } - private static class Types extends BaseProtoTypes { + /** + * Non-static inner class that can access the outer TypeProtoConverter instance. + * + *

This class must be non-static to access TypeProtoConverter.this.toProto() for converting + * nested type parameters (e.g., ParameterDataType containing another Type). Being non-static + * means instances are bound to a specific outer TypeProtoConverter instance, ensuring parameter + * conversions use the correct ExtensionCollector. + */ + private class Types extends BaseProtoTypes { public Types(final Type.Nullability nullability) { super(nullability); @@ -133,6 +176,20 @@ public Type userDefined(int ref) { Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build()); } + @Override + public Type userDefined( + int ref, java.util.List typeParameters) { + return wrap( + Type.UserDefined.newBuilder() + .setTypeReference(ref) + .setNullability(nullability) + .addAllTypeParameters( + typeParameters.stream() + .map(TypeProtoConverter.this::toProto) + .collect(java.util.stream.Collectors.toList())) + .build()); + } + @Override protected Type wrap(final Object o) { Type.Builder bldr = Type.newBuilder(); diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 3defbf78f..e2fa7a91c 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -1,8 +1,12 @@ package io.substrait; +import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; @@ -25,9 +29,22 @@ public abstract class TestBase { protected ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, defaultExtensionCollection); + protected ExpressionProtoConverter expressionProtoConverter = + relProtoConverter.getExpressionProtoConverter(); + + protected ProtoExpressionConverter protoExpressionConverter = + new ProtoExpressionConverter( + functionCollector, defaultExtensionCollection, EMPTY_TYPE, protoRelConverter); + protected void verifyRoundTrip(Rel rel) { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } + + protected void verifyRoundTrip(Expression expression) { + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(expression); + Expression expressionReturned = protoExpressionConverter.from(protoExpression); + assertEquals(expression, expressionReturned); + } } diff --git a/core/src/test/java/io/substrait/plan/PlanConverterTest.java b/core/src/test/java/io/substrait/plan/PlanConverterTest.java index dd49cf207..c56fec83c 100644 --- a/core/src/test/java/io/substrait/plan/PlanConverterTest.java +++ b/core/src/test/java/io/substrait/plan/PlanConverterTest.java @@ -3,14 +3,22 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; import io.substrait.extension.AdvancedExtension; +import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan.Root; import io.substrait.relation.EmptyScan; +import io.substrait.relation.ImmutableVirtualTableScan; +import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import io.substrait.utils.StringHolder; import io.substrait.utils.StringHolderHandlingExtensionProtoConverter; import io.substrait.utils.StringHolderHandlingProtoExtensionConverter; +import java.util.Arrays; +import java.util.Collections; import org.junit.jupiter.api.Test; class PlanConverterTest { @@ -189,4 +197,127 @@ void planIncludingRelationWithAdvancedExtension() { assertEquals(plan, plan2); } + + /** + * Verifies that nested UserDefined types with type parameters share the same ExtensionCollector + * and don't create duplicate type references. Tests that a plan containing both a standalone + * UserDefined literal (point) and a parameterized UserDefined literal (vector) correctly + * registers both types in the extension collection without duplication. + */ + @Test + void nestedUserDefinedTypesShareExtensionCollector() { + // Define custom types: point and vector + String urn = "extension:test:nested_types"; + String yaml = + "---\n" + + "urn: " + + urn + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " x: i32\n" + + " y: i32\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n"; + + SimpleExtension.ExtensionCollection extensions = SimpleExtension.load("test.yaml", yaml); + + // Create type objects + Type pointType = Type.UserDefined.builder().nullable(false).urn(urn).name("point").build(); + + Type.Parameter pointTypeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder().type(pointType).build(); + + Type vectorOfPointType = + Type.UserDefined.builder() + .nullable(false) + .urn(urn) + .name("vector") + .addTypeParameters(pointTypeParam) + .build(); + + // Create literals + Expression.UserDefinedStructLiteral pointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList(ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 20))); + + // Create vector literal: vector{(1,2), (3,4), (5,6)} + Expression.UserDefinedStructLiteral vectorOfPointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "vector", + Arrays.asList(pointTypeParam), + Arrays.asList( + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 1), ExpressionCreator.i32(false, 2))), + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 3), ExpressionCreator.i32(false, 4))), + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 6))))); + + Type nullablePointType = + Type.UserDefined.builder().nullable(true).urn(urn).name("point").build(); + + Expression.UserDefinedStructLiteral nullablePointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + true, + urn, + "point", + Collections.emptyList(), + Arrays.asList(ExpressionCreator.i32(false, 30), ExpressionCreator.i32(false, 40))); + + // Create virtual table with all three columns (nullable point, required point, required vector) + VirtualTableScan virtualTable = + ImmutableVirtualTableScan.builder() + .initialSchema( + NamedStruct.of( + Arrays.asList("nullable_point_col", "point_col", "vector_col"), + TypeCreator.REQUIRED.struct(nullablePointType, pointType, vectorOfPointType))) + .addRows( + ExpressionCreator.nestedStruct( + false, nullablePointLiteral, pointLiteral, vectorOfPointLiteral)) + .build(); + + Plan plan = Plan.builder().addRoots(Root.builder().input(virtualTable).build()).build(); + + PlanProtoConverter toProtoConverter = new PlanProtoConverter(); + io.substrait.proto.Plan protoPlan = toProtoConverter.toProto(plan); + + assertEquals(1, protoPlan.getExtensionUrnsCount(), "Should have exactly 1 extension URN"); + assertEquals( + 2, + protoPlan.getExtensionsCount(), + "Should have exactly 2 type extensions (point and vector), no duplicates"); + + ProtoPlanConverter fromProtoConverter = new ProtoPlanConverter(extensions); + Plan roundTrippedPlan = fromProtoConverter.from(protoPlan); + assertEquals(plan, roundTrippedPlan, "Plan should roundtrip correctly"); + } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index bfb417365..69516c29f 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -3,23 +3,288 @@ import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.protobuf.Any; import io.substrait.TestBase; +import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ProtoExpressionConverter; -import io.substrait.util.EmptyVisitationContext; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.RelProtoConverter; import java.math.BigDecimal; +import java.util.Collections; import org.junit.jupiter.api.Test; class LiteralRoundtripTest extends TestBase { + private static final String NESTED_TYPES_URN = "extension:io.substrait:test_nested_types"; + + private static final String NESTED_TYPES_YAML = + "---\n" + + "urn: " + + NESTED_TYPES_URN + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " latitude: i32\n" + + " longitude: i32\n" + + " - name: triangle\n" + + " structure:\n" + + " p1: point\n" + + " p2: point\n" + + " p3: point\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n" + + " - name: multi_param\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " - name: size\n" + + " type: integer\n" + + " - name: nullable\n" + + " type: boolean\n" + + " - name: encoding\n" + + " type: string\n" + + " - name: precision\n" + + " type: dataType\n" + + " - name: mode\n" + + " type: enum\n" + + " structure:\n" + + " value: T\n"; + + private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = + SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); + + private static final ExtensionCollector NESTED_TYPES_FUNCTION_COLLECTOR = + new ExtensionCollector(); + private static final RelProtoConverter NESTED_TYPES_REL_PROTO_CONVERTER = + new RelProtoConverter(NESTED_TYPES_FUNCTION_COLLECTOR); + private static final ProtoRelConverter NESTED_TYPES_PROTO_REL_CONVERTER = + new ProtoRelConverter(NESTED_TYPES_FUNCTION_COLLECTOR, NESTED_TYPES_EXTENSIONS); + private static final ExpressionProtoConverter NESTED_TYPES_EXPRESSION_TO_PROTO = + new ExpressionProtoConverter( + NESTED_TYPES_FUNCTION_COLLECTOR, NESTED_TYPES_REL_PROTO_CONVERTER); + private static final ProtoExpressionConverter NESTED_TYPES_PROTO_TO_EXPRESSION = + new ProtoExpressionConverter( + NESTED_TYPES_FUNCTION_COLLECTOR, + NESTED_TYPES_EXTENSIONS, + EMPTY_TYPE, + NESTED_TYPES_PROTO_REL_CONVERTER); + + private void verifyNestedTypesRoundTrip(Expression expression) { + io.substrait.proto.Expression protoExpression = + NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(expression); + Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); + assertEquals(expression, result); + } + @Test void decimal() { io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = - new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); + verifyRoundTrip(val); + } + + /** Verifies round-trip conversion of a simple user-defined type using Any representation. */ + @Test + void userDefinedLiteralWithAnyRepresentation() { + Any anyValue = + Any.pack(com.google.protobuf.StringValue.of("")); + + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue); + + verifyRoundTrip(val); + } + + /** Verifies round-trip conversion of a simple user-defined type using Struct representation. */ + @Test + void userDefinedLiteralWithStructRepresentation() { + java.util.List fields = + java.util.Arrays.asList( + ExpressionCreator.i32(false, 42), ExpressionCreator.i32(false, 100)); + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralStruct( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + fields); + + verifyRoundTrip(val); + } + + /** + * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three + * point UDTs. Both outer and nested types use Struct representation. + */ + @Test + void nestedUserDefinedLiteralWithStructRepresentation() { + Expression.UserDefinedStructLiteral p1 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 0), ExpressionCreator.i32(false, 0))); + + Expression.UserDefinedStructLiteral p2 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 0))); + + Expression.UserDefinedStructLiteral p3 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 10))); + + Expression.UserDefinedStructLiteral triangle = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "triangle", + Collections.emptyList(), + java.util.Arrays.asList(p1, p2, p3)); + + verifyNestedTypesRoundTrip(triangle); + } + + /** + * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three + * point UDTs. Both outer and nested types use Any representation. + */ + @Test + void nestedUserDefinedLiteralWithAnyRepresentation() { + Any triangleAny = + Any.pack(com.google.protobuf.StringValue.of("")); + + Expression.UserDefinedAnyLiteral triangle = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "triangle", Collections.emptyList(), triangleAny); + + verifyNestedTypesRoundTrip(triangle); + } + + /** + * Verifies round-trip conversion of nested user-defined types with mixed representations. The + * triangle UDT uses Struct representation while the nested point UDTs use Any representation. + */ + @Test + void mixedRepresentationNestedUserDefinedLiteral() { + Any anyValue = + Any.pack(com.google.protobuf.StringValue.of("")); + + // Create point UDTs using Any representation + Expression.UserDefinedAnyLiteral p1 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); + + Expression.UserDefinedAnyLiteral p2 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); + + Expression.UserDefinedAnyLiteral p3 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); + + // Create a "triangle" UDT using Struct representation, but with Any-encoded point fields + Expression.UserDefinedStructLiteral triangle = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "triangle", + Collections.emptyList(), + java.util.Arrays.asList(p1, p2, p3)); + + verifyNestedTypesRoundTrip(triangle); + } + + /** + * Verifies round-trip conversion of a parameterized user-defined type. Tests that type parameters + * are correctly preserved during serialization and deserialization. + */ + @Test + void userDefinedLiteralWithTypeParameters() { + // Create a type parameter for i32 + io.substrait.type.Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + // Create a vector instance with fields (x: 1, y: 2, z: 3) + Expression.UserDefinedStructLiteral vectorI32 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "vector", + java.util.Arrays.asList(typeParam), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 1), + ExpressionCreator.i32(false, 2), + ExpressionCreator.i32(false, 3))); + + verifyNestedTypesRoundTrip(vectorI32); + } + + /** + * Verifies round-trip conversion of a user-defined type with all parameter types. Tests that all + * parameter kinds (type, integer, boolean, string, null, enum) are correctly preserved during + * serialization and deserialization. + */ + @Test + void userDefinedLiteralWithAllParameterTypes() { + io.substrait.type.Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + io.substrait.type.Type.Parameter intParam = + io.substrait.type.ImmutableType.ParameterIntegerValue.builder().value(100L).build(); + + io.substrait.type.Type.Parameter boolParam = + io.substrait.type.ImmutableType.ParameterBooleanValue.builder().value(true).build(); + + io.substrait.type.Type.Parameter stringParam = + io.substrait.type.ImmutableType.ParameterStringValue.builder().value("utf8").build(); + + io.substrait.type.Type.Parameter nullParam = io.substrait.type.Type.ParameterNull.INSTANCE; + + io.substrait.type.Type.Parameter enumParam = + io.substrait.type.ImmutableType.ParameterEnumValue.builder().value("FAST").build(); + + Expression.UserDefinedStructLiteral multiParam = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "multi_param", + java.util.Arrays.asList( + typeParam, intParam, boolParam, stringParam, nullParam, enumParam), + java.util.Arrays.asList(ExpressionCreator.i32(false, 42))); + + verifyNestedTypesRoundTrip(multiParam); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 838feb508..2a857a2b6 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -38,7 +38,8 @@ import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.Expression.TimestampTZLiteral; import io.substrait.expression.Expression.UUIDLiteral; -import io.substrait.expression.Expression.UserDefinedLiteral; +import io.substrait.expression.Expression.UserDefinedAnyLiteral; +import io.substrait.expression.Expression.UserDefinedStructLiteral; import io.substrait.expression.Expression.VarCharLiteral; import io.substrait.expression.Expression.WindowFunctionInvocation; import io.substrait.expression.ExpressionVisitor; @@ -195,9 +196,15 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context } @Override - public String visit(UserDefinedLiteral expr, EmptyVisitationContext context) + public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; + } + + @Override + public String visit(UserDefinedStructLiteral expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; } @Override diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index a94d47fd4..f932b8044 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -96,7 +96,7 @@ public static List explain(io.substrait.plan.Plan plan) { /** * Explains the Sustrait relation * - * @param plan Subsrait relation + * @param rel Subsrait relation * @return List of strings; typically these would then be logged or sent to stdout */ public static List explain(io.substrait.relation.Rel rel) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3406de7de..72198d49c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -44,15 +44,16 @@ public class CallConverters { * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link * Expression.UserDefinedLiteral}s within Calcite. * - *

When converting from Substrait to Calcite, the {@link Expression.UserDefinedLiteral#value()} - * is stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link - * org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type. + *

When converting from Substrait to Calcite, the {@link + * Expression.UserDefinedAnyLiteral#value()} is stored within a {@link + * org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link org.apache.calcite.rex.RexLiteral} and + * then re-interpreted to have the correct type. * - *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral, + *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAnyLiteral, * SubstraitRelNodeConverter.Context)} for this conversion. * *

When converting from Calcite to Substrait, this call converter extracts the {@link - * Expression.UserDefinedLiteral} that was stored. + * Expression.UserDefinedAnyLiteral} that was stored. */ public static Function REINTERPRET = typeConverter -> @@ -70,11 +71,21 @@ public class CallConverters { Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; Type.UserDefined t = (Type.UserDefined) type; - return Expression.UserDefinedLiteral.builder() - .urn(t.urn()) - .name(t.name()) - .value(literal.value()) - .build(); + // The binary literal contains the serialized protobuf Any - just parse it directly + try { + com.google.protobuf.Any anyValue = + com.google.protobuf.Any.parseFrom(literal.value().toByteArray()); + + return Expression.UserDefinedAnyLiteral.builder() + .nullable(t.nullable()) + .urn(t.urn()) + .name(t.name()) + .addAllTypeParameters(t.typeParameters()) + .value(anyValue) + .build(); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw new IllegalStateException("Failed to parse UserDefinedAnyLiteral value", e); + } } return null; }; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2b8052889..68936a6d4 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -109,7 +109,7 @@ public RexNode visit(Expression.NullLiteral expr, Context context) throws Runtim } @Override - public RexNode visit(Expression.UserDefinedLiteral expr, Context context) + public RexNode visit(Expression.UserDefinedAnyLiteral expr, Context context) throws RuntimeException { RexLiteral binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); @@ -117,6 +117,13 @@ public RexNode visit(Expression.UserDefinedLiteral expr, Context context) return rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); } + @Override + public RexNode visit(Expression.UserDefinedStructLiteral expr, Context context) + throws RuntimeException { + throw new UnsupportedOperationException( + "UserDefinedStructLiteral representation is not yet supported in Isthmus"); + } + @Override public RexNode visit(Expression.BoolLiteral expr, Context context) throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 34e06d0ac..98aa04203 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -7,7 +7,6 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; -import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.FunctionMappings; @@ -16,9 +15,7 @@ import io.substrait.isthmus.utils.UserTypeFactory; import io.substrait.proto.Expression; import io.substrait.proto.Expression.Literal.Builder; -import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; -import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; @@ -587,7 +584,9 @@ void customTypesInFunctionsRoundtrip() { void customTypesLiteralInFunctionsRoundtrip() { Builder bldr = Expression.Literal.newBuilder(); Any anyValue = Any.pack(bldr.setI32(10).build()); - UserDefinedLiteral val = ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); + UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue); Rel rel1 = b.project( @@ -599,10 +598,24 @@ void customTypesLiteralInFunctionsRoundtrip() { RelNode calciteRel = substraitToCalcite.convert(rel1); Rel rel2 = calciteToSubstrait.apply(calciteRel); assertEquals(rel1, rel2); + } - ExtensionCollector extensionCollector = new ExtensionCollector(); - io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); - Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); - assertEquals(rel1, rel3); + @Test + void customNullableUserDefinedLiteralRoundtrip() { + Builder bldr = Expression.Literal.newBuilder(); + Any anyValue = Any.pack(bldr.setI32(10).build()); + UserDefinedLiteral nullableLiteral = + ExpressionCreator.userDefinedLiteralAny( + true, URN, "a_type", java.util.Collections.emptyList(), anyValue); + + Rel rel = + b.project( + input -> List.of(nullableLiteral), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + Rel relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); } } diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 5377f4257..10c134658 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -77,7 +77,13 @@ class ExpressionToString extends DefaultExpressionVisitor[String] { } override def visit( - expr: Expression.UserDefinedLiteral, + expr: Expression.UserDefinedAnyLiteral, + context: EmptyVisitationContext): String = { + expr.toString + } + + override def visit( + expr: Expression.UserDefinedStructLiteral, context: EmptyVisitationContext): String = { expr.toString } diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala index 5f7137b14..d1b7a32a6 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -65,9 +65,11 @@ class DefaultExpressionVisitor[T] context: EmptyVisitationContext): T = e.accept(this, context) + override def visit(expr: Expression.UserDefinedAnyLiteral, context: EmptyVisitationContext): T = + visitFallback(expr, context) + override def visit( - userDefinedLiteral: Expression.UserDefinedLiteral, - context: EmptyVisitationContext): T = { - visitFallback(userDefinedLiteral, context) - } + expr: Expression.UserDefinedStructLiteral, + context: EmptyVisitationContext): T = + visitFallback(expr, context) }