diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 072507295..6b49cae19 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E { return visitFallback(expr, context); } + @Override + public O visit(Expression.UserDefinedAnyLiteral expr, C context) throws E { + return visitFallback(expr, context); + } + + @Override + public O visit(Expression.UserDefinedStructLiteral expr, C context) throws E { + return visitFallback(expr, context); + } + @Override public O visit(Expression.Switch expr, C context) throws E { return visitFallback(expr, context); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..7f0fc8748 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -662,21 +662,96 @@ public R accept( } } + /** + * Base interface for user-defined literals. + * + *

User-defined literals can be encoded in one of two ways as per the Substrait spec: + * + *

+ * + * @see UserDefinedAnyLiteral + * @see UserDefinedStructLiteral + */ + interface UserDefinedLiteral extends Literal { + String urn(); + + String name(); + + List typeParameters(); + } + + /** + * User-defined literal with value encoded as {@code google.protobuf.Any}. + * + *

This encoding allows for arbitrary binary data to be stored in the literal value. + */ @Value.Immutable - abstract class UserDefinedLiteral implements Literal { - public abstract ByteString value(); + abstract class UserDefinedAnyLiteral implements UserDefinedLiteral { + @Override + public abstract String urn(); + + @Override + public abstract String name(); + + @Override + public abstract List typeParameters(); + + public abstract com.google.protobuf.Any value(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); + } + + public static ImmutableExpression.UserDefinedAnyLiteral.Builder builder() { + return ImmutableExpression.UserDefinedAnyLiteral.builder(); + } + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + + /** + * User-defined literal with value encoded as {@code Literal.Struct}. + * + *

This encoding uses a structured list of fields to represent the literal value. + */ + @Value.Immutable + abstract class UserDefinedStructLiteral implements UserDefinedLiteral { + @Override public abstract String urn(); + @Override public abstract String name(); @Override - public Type getType() { - return Type.withNullability(nullable()).userDefined(urn(), name()); + public abstract List typeParameters(); + + public abstract List fields(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); } - public static ImmutableExpression.UserDefinedLiteral.Builder builder() { - return ImmutableExpression.UserDefinedLiteral.builder(); + public static ImmutableExpression.UserDefinedStructLiteral.Builder builder() { + return ImmutableExpression.UserDefinedStructLiteral.builder(); } @Override diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index adf157d7b..0b3fb6cd8 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -286,13 +286,51 @@ public static Expression.StructLiteral struct( return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build(); } - public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String urn, String name, Any value) { - return Expression.UserDefinedLiteral.builder() + /** + * Create a UserDefinedAnyLiteral with google.protobuf.Any representation. 
+ * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param value the value, encoded as google.protobuf.Any + */ + public static Expression.UserDefinedAnyLiteral userDefinedLiteralAny( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + Any value) { + return Expression.UserDefinedAnyLiteral.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(typeParameters) + .value(value) + .build(); + } + + /** + * Create a UserDefinedStructLiteral with Struct representation. + * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param fields the fields, as a list of Literal values + */ + public static Expression.UserDefinedStructLiteral userDefinedLiteralStruct( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + java.util.List fields) { + return Expression.UserDefinedStructLiteral.builder() .nullable(nullable) .urn(urn) .name(name) - .value(value.toByteString()) + .addAllTypeParameters(typeParameters) + .addAllFields(fields) .build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index d64cab48c..43f54cadf 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -62,7 +62,9 @@ public interface ExpressionVisitor { - try { - bldr.setNullable(expr.nullable()) - .setUserDefined( - Expression.Literal.UserDefined.newBuilder() - .setTypeReference(typeReference) - .setValue(Any.parseFrom(expr.value()))) - 
.build(); - } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException(e); - } + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters( + expr.typeParameters().stream() + .map(typeProtoConverter::toProto) + .collect(java.util.stream.Collectors.toList())) + .setValue(expr.value()); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); + }); + } + + @Override + public Expression visit( + io.substrait.expression.Expression.UserDefinedStructLiteral expr, + EmptyVisitationContext context) { + int typeReference = + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); + return lit( + bldr -> { + Expression.Literal.Struct structLiteral = + Expression.Literal.Struct.newBuilder() + .addAllFields( + expr.fields().stream() + .map(this::toLiteral) + .collect(java.util.stream.Collectors.toList())) + .build(); + + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters( + expr.typeParameters().stream() + .map(typeProtoConverter::toProto) + .collect(java.util.stream.Collectors.toList())) + .setStruct(structLiteral); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); }); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8f95cdf07..39c5e2f9a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -492,10 +492,40 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = literal.getUserDefined(); + 
SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); - return ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); + String urn = type.urn(); + String name = type.name(); + + switch (userDefinedLiteral.getValCase()) { + case VALUE: + return ExpressionCreator.userDefinedLiteralAny( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList().stream() + .map(protoTypeConverter::from) + .collect(Collectors.toList()), + userDefinedLiteral.getValue()); + case STRUCT: + return ExpressionCreator.userDefinedLiteralStruct( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList().stream() + .map(protoTypeConverter::from) + .collect(Collectors.toList()), + userDefinedLiteral.getStruct().getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())); + case VAL_NOT_SET: + throw new IllegalStateException( + "UserDefined literal has no value (neither 'value' nor 'struct' is set)"); + default: + throw new IllegalStateException( + "Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase()); + } } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 89aad954e..31214878c 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog { "extension:io.substrait:functions_rounding_decimal"; public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; + public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types"; 
public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = loadDefaultCollection(); @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { .map(c -> String.format("/functions_%s.yaml", c)) .collect(Collectors.toList()); + defaultFiles.add("/extension_types.yaml"); + return SimpleExtension.load(defaultFiles); } } diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 57132a940..6c961e4d3 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -205,7 +205,13 @@ public Optional visit(Expression.StructLiteral expr, EmptyVisitation @Override public Optional visit( - Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E { + Expression.UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws E { + return visitLiteral(expr); + } + + @Override + public Optional visit( + Expression.UserDefinedStructLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..c597e7e67 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -393,6 +393,30 @@ abstract class UserDefined implements Type { public abstract String name(); + /** + * Returns the type parameters for this user-defined type. + * + *

Type parameters are used to represent parameterized/generic types, such as {@code + * vector<i32>} or custom types like {@code FixedArray<100>}. Each parameter in the list can be + * either a type (like {@code i32}) or a value (like the integer {@code 100}). + * + *

Unlike built-in parameterized types ({@link Map}, {@link ListType}, {@link Decimal}), + * which have fixed, known schemas with concrete typed fields, user-defined types have variable, + * unknown schemas. This is why UserDefined uses a generic {@link Parameter} list that can hold + * any mix of types or values, while other parameterized types use concrete fields like {@code + * Type key()} or {@code int precision()}. + * + *

For example, a user-defined {@code vector} type parameterized by {@code i32} would have + * one type parameter containing the {@code i32} type definition, while a {@code FixedArray} + * type might take an integer parameter specifying its size. + * + * @return a list of type parameters, or an empty list if this type is not parameterized + */ + @Value.Default + public java.util.List typeParameters() { + return java.util.Collections.emptyList(); + } + public static ImmutableType.UserDefined.Builder builder() { return ImmutableType.UserDefined.builder(); } @@ -402,4 +426,50 @@ public R accept(TypeVisitor typeVisitor) throws E return typeVisitor.visit(this); } } + + /** + * Represents a type parameter for user-defined types. + * + *

Type parameters can be data types (like {@code i32} in {@code List}), or value + * parameters (like the {@code 10} in {@code VARCHAR<10>}). This interface provides a type-safe + * representation of all possible parameter kinds. + */ + interface Parameter {} + + /** A data type parameter, such as the {@code i32} in {@code List}. */ + @Value.Immutable + abstract class ParameterDataType implements Parameter { + public abstract Type type(); + } + + /** A boolean value parameter. */ + @Value.Immutable + abstract class ParameterBooleanValue implements Parameter { + public abstract boolean value(); + } + + /** An integer value parameter, such as the {@code 10} in {@code VARCHAR<10>}. */ + @Value.Immutable + abstract class ParameterIntegerValue implements Parameter { + public abstract long value(); + } + + /** An enum value parameter (represented as a string). */ + @Value.Immutable + abstract class ParameterEnumValue implements Parameter { + public abstract String value(); + } + + /** A string value parameter. */ + @Value.Immutable + abstract class ParameterStringValue implements Parameter { + public abstract String value(); + } + + /** An explicitly null/unspecified parameter, used to select the default value (if any). 
*/ + class ParameterNull implements Parameter { + public static final ParameterNull INSTANCE = new ParameterNull(); + + private ParameterNull() {} + } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 691d4bce5..67d7bc9b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -165,6 +165,6 @@ public final T visit(final Type.Map expr) { public final T visit(final Type.UserDefined expr) { int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); - return typeContainer(expr).userDefined(ref); + return typeContainer(expr).userDefined(ref, expr.typeParameters()); } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 6a1bc3186..57b1f26b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -131,6 +131,9 @@ public final T struct(T... 
types) { public abstract T userDefined(int ref); + public abstract T userDefined( + int ref, java.util.List typeParameters); + protected abstract T wrap(Object o); protected abstract I i(int integerValue); diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 4e0caa7c2..817f7b0b3 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -262,6 +262,13 @@ public ParameterizedType userDefined(int ref) { "User defined types are not supported in Parameterized Types for now"); } + @Override + public ParameterizedType userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Parameterized Types for now"); + } + @Override protected ParameterizedType wrap(final Object o) { ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 95d42328a..bdb600c1c 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -2,6 +2,7 @@ import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; +import io.substrait.type.ImmutableType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; @@ -90,7 +91,16 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); + boolean nullable = isNullable(userDefined.getNullability()); + return 
io.substrait.type.Type.UserDefined.builder() + .nullable(nullable) + .urn(t.urn()) + .name(t.name()) + .typeParameters( + userDefined.getTypeParametersList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList())) + .build(); } case USER_DEFINED_TYPE_REFERENCE: throw new UnsupportedOperationException("Unsupported user defined reference: " + type); @@ -118,4 +128,28 @@ private static TypeCreator n(io.substrait.proto.Type.Nullability n) { ? TypeCreator.NULLABLE : TypeCreator.REQUIRED; } + + public io.substrait.type.Type.Parameter from(io.substrait.proto.Type.Parameter parameter) { + switch (parameter.getParameterCase()) { + case NULL: + return io.substrait.type.Type.ParameterNull.INSTANCE; + case DATA_TYPE: + return ImmutableType.ParameterDataType.builder() + .type(from(parameter.getDataType())) + .build(); + case BOOLEAN: + return ImmutableType.ParameterBooleanValue.builder().value(parameter.getBoolean()).build(); + case INTEGER: + return ImmutableType.ParameterIntegerValue.builder().value(parameter.getInteger()).build(); + case ENUM: + return ImmutableType.ParameterEnumValue.builder().value(parameter.getEnum()).build(); + case STRING: + return ImmutableType.ParameterStringValue.builder().value(parameter.getString()).build(); + case PARAMETER_NOT_SET: + throw new IllegalArgumentException("Parameter type is not set: " + parameter); + default: + throw new UnsupportedOperationException( + "Unsupported parameter type: " + parameter.getParameterCase()); + } + } } diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 96cddd395..f9d2129d2 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -355,6 +355,13 @@ public DerivationExpression userDefined(int ref) { "User defined types are not supported in Derivation 
Expressions for now"); } + @Override + public DerivationExpression userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Derivation Expressions for now"); + } + @Override protected DerivationExpression wrap(final Object o) { DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 2d0ed0ffc..6422904c4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -5,25 +5,68 @@ /** Convert from {@link io.substrait.type.Type} to {@link io.substrait.proto.Type} */ public class TypeProtoConverter extends BaseProtoConverter { - private static final BaseProtoTypes NULLABLE = - new Types(Type.Nullability.NULLABILITY_NULLABLE); - private static final BaseProtoTypes REQUIRED = - new Types(Type.Nullability.NULLABILITY_REQUIRED); + // Instance fields (not static) because Types is a non-static inner class that calls + // TypeProtoConverter.this.toProto() to recursively convert nested type parameters. + // Each converter instance needs its own Types instances to ensure type registrations + // use the correct ExtensionCollector. 
+ private final BaseProtoTypes NULLABLE; + private final BaseProtoTypes REQUIRED; public TypeProtoConverter(ExtensionCollector extensionCollector) { super(extensionCollector, "Type literals cannot contain parameters or expressions."); + NULLABLE = new Types(Type.Nullability.NULLABILITY_NULLABLE); + REQUIRED = new Types(Type.Nullability.NULLABILITY_REQUIRED); } public io.substrait.proto.Type toProto(io.substrait.type.Type type) { return type.accept(this); } + public io.substrait.proto.Type.Parameter toProto(io.substrait.type.Type.Parameter parameter) { + if (parameter instanceof io.substrait.type.Type.ParameterNull) { + return Type.Parameter.newBuilder() + .setNull(com.google.protobuf.Empty.getDefaultInstance()) + .build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterDataType) { + io.substrait.type.Type.ParameterDataType dataType = + (io.substrait.type.Type.ParameterDataType) parameter; + return Type.Parameter.newBuilder().setDataType(toProto(dataType.type())).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterBooleanValue) { + io.substrait.type.Type.ParameterBooleanValue boolValue = + (io.substrait.type.Type.ParameterBooleanValue) parameter; + return Type.Parameter.newBuilder().setBoolean(boolValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterIntegerValue) { + io.substrait.type.Type.ParameterIntegerValue intValue = + (io.substrait.type.Type.ParameterIntegerValue) parameter; + return Type.Parameter.newBuilder().setInteger(intValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterEnumValue) { + io.substrait.type.Type.ParameterEnumValue enumValue = + (io.substrait.type.Type.ParameterEnumValue) parameter; + return Type.Parameter.newBuilder().setEnum(enumValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterStringValue) { + io.substrait.type.Type.ParameterStringValue stringValue = + 
(io.substrait.type.Type.ParameterStringValue) parameter; + return Type.Parameter.newBuilder().setString(stringValue.value()).build(); + } else { + throw new UnsupportedOperationException( + "Unsupported parameter type: " + parameter.getClass()); + } + } + @Override public BaseProtoTypes typeContainer(final boolean nullable) { return nullable ? NULLABLE : REQUIRED; } - private static class Types extends BaseProtoTypes { + /** + * Non-static inner class that can access the outer TypeProtoConverter instance. + * + *

This class must be non-static to access TypeProtoConverter.this.toProto() for converting + * nested type parameters (e.g., ParameterDataType containing another Type). Being non-static + * means instances are bound to a specific outer TypeProtoConverter instance, ensuring parameter + * conversions use the correct ExtensionCollector. + */ + private class Types extends BaseProtoTypes { public Types(final Type.Nullability nullability) { super(nullability); @@ -133,6 +176,20 @@ public Type userDefined(int ref) { Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build()); } + @Override + public Type userDefined( + int ref, java.util.List typeParameters) { + return wrap( + Type.UserDefined.newBuilder() + .setTypeReference(ref) + .setNullability(nullability) + .addAllTypeParameters( + typeParameters.stream() + .map(TypeProtoConverter.this::toProto) + .collect(java.util.stream.Collectors.toList())) + .build()); + } + @Override protected Type wrap(final Object o) { Type.Builder bldr = Type.newBuilder(); diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 3defbf78f..b5f1dd4f1 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -1,8 +1,12 @@ package io.substrait; +import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; @@ -25,9 +29,22 @@ public abstract class TestBase { protected ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, 
defaultExtensionCollection); + protected ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + + protected ProtoExpressionConverter protoExpressionConverter = + new ProtoExpressionConverter( + functionCollector, defaultExtensionCollection, EMPTY_TYPE, protoRelConverter); + protected void verifyRoundTrip(Rel rel) { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } + + protected void verifyRoundTrip(Expression expression) { + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(expression); + Expression expressionReturned = protoExpressionConverter.from(protoExpression); + assertEquals(expression, expressionReturned); + } } diff --git a/core/src/test/java/io/substrait/plan/PlanConverterTest.java b/core/src/test/java/io/substrait/plan/PlanConverterTest.java index dd49cf207..97268e627 100644 --- a/core/src/test/java/io/substrait/plan/PlanConverterTest.java +++ b/core/src/test/java/io/substrait/plan/PlanConverterTest.java @@ -3,14 +3,22 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; import io.substrait.extension.AdvancedExtension; +import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan.Root; import io.substrait.relation.EmptyScan; +import io.substrait.relation.ImmutableVirtualTableScan; +import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import io.substrait.utils.StringHolder; import io.substrait.utils.StringHolderHandlingExtensionProtoConverter; import io.substrait.utils.StringHolderHandlingProtoExtensionConverter; +import java.util.Arrays; +import java.util.Collections; import 
org.junit.jupiter.api.Test; class PlanConverterTest { @@ -189,4 +197,127 @@ void planIncludingRelationWithAdvancedExtension() { assertEquals(plan, plan2); } + + /** + * Verifies that nested UserDefined types with type parameters share the same ExtensionCollector + * and don't create duplicate type references. Tests that a plan containing both a standalone + * UserDefined literal (point) and a parameterized UserDefined literal (vector) correctly + * registers both types in the extension collection without duplication. + */ + @Test + void nestedUserDefinedTypesShareExtensionCollector() { + // Define custom types: point and vector + String urn = "extension:test:nested_types"; + String yaml = + "---\n" + + "urn: " + + urn + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " x: i32\n" + + " y: i32\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n"; + + SimpleExtension.ExtensionCollection extensions = SimpleExtension.load("test.yaml", yaml); + + // Create type objects + Type pointType = Type.UserDefined.builder().nullable(false).urn(urn).name("point").build(); + + Type.Parameter pointTypeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder().type(pointType).build(); + + Type vectorOfPointType = + Type.UserDefined.builder() + .nullable(false) + .urn(urn) + .name("vector") + .addTypeParameters(pointTypeParam) + .build(); + + // Create literals + Expression.UserDefinedStructLiteral pointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList(ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 20))); + + // Create vector literal: vector{(1,2), (3,4), (5,6)} + Expression.UserDefinedStructLiteral vectorOfPointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "vector", + Arrays.asList(pointTypeParam), + Arrays.asList( + 
ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 1), ExpressionCreator.i32(false, 2))), + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 3), ExpressionCreator.i32(false, 4))), + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 6))))); + + Type nullablePointType = + Type.UserDefined.builder().nullable(true).urn(urn).name("point").build(); + + Expression.UserDefinedStructLiteral nullablePointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + true, + urn, + "point", + Collections.emptyList(), + Arrays.asList(ExpressionCreator.i32(false, 30), ExpressionCreator.i32(false, 40))); + + // Create virtual table with all three columns (nullable point, required point, required vector) + VirtualTableScan virtualTable = + ImmutableVirtualTableScan.builder() + .initialSchema( + NamedStruct.of( + Arrays.asList("nullable_point_col", "point_col", "vector_col"), + TypeCreator.REQUIRED.struct(nullablePointType, pointType, vectorOfPointType))) + .addRows( + ExpressionCreator.struct( + false, nullablePointLiteral, pointLiteral, vectorOfPointLiteral)) + .build(); + + Plan plan = Plan.builder().addRoots(Root.builder().input(virtualTable).build()).build(); + + PlanProtoConverter toProtoConverter = new PlanProtoConverter(); + io.substrait.proto.Plan protoPlan = toProtoConverter.toProto(plan); + + assertEquals(1, protoPlan.getExtensionUrnsCount(), "Should have exactly 1 extension URN"); + assertEquals( + 2, + protoPlan.getExtensionsCount(), + "Should have exactly 2 type extensions (point and vector), no duplicates"); + + ProtoPlanConverter fromProtoConverter = new ProtoPlanConverter(extensions); + Plan roundTrippedPlan = fromProtoConverter.from(protoPlan); + 
assertEquals(plan, roundTrippedPlan, "Plan should roundtrip correctly"); + } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index ccac93bcb..e70fe685e 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -3,23 +3,288 @@ import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.protobuf.Any; import io.substrait.TestBase; +import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ProtoExpressionConverter; -import io.substrait.util.EmptyVisitationContext; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.RelProtoConverter; import java.math.BigDecimal; +import java.util.Collections; import org.junit.jupiter.api.Test; public class LiteralRoundtripTest extends TestBase { + private static final String NESTED_TYPES_URN = "extension:io.substrait:test_nested_types"; + + private static final String NESTED_TYPES_YAML = + "---\n" + + "urn: " + + NESTED_TYPES_URN + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " latitude: i32\n" + + " longitude: i32\n" + + " - name: triangle\n" + + " structure:\n" + + " p1: point\n" + + " p2: point\n" + + " p3: point\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n" + + " - name: multi_param\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " - name: size\n" + + " type: 
integer\n" + + " - name: nullable\n" + + " type: boolean\n" + + " - name: encoding\n" + + " type: string\n" + + " - name: precision\n" + + " type: dataType\n" + + " - name: mode\n" + + " type: enum\n" + + " structure:\n" + + " value: T\n"; + + private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = + SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); + + private static final ExtensionCollector NESTED_TYPES_FUNCTION_COLLECTOR = + new ExtensionCollector(); + private static final RelProtoConverter NESTED_TYPES_REL_PROTO_CONVERTER = + new RelProtoConverter(NESTED_TYPES_FUNCTION_COLLECTOR); + private static final ProtoRelConverter NESTED_TYPES_PROTO_REL_CONVERTER = + new ProtoRelConverter(NESTED_TYPES_FUNCTION_COLLECTOR, NESTED_TYPES_EXTENSIONS); + private static final ExpressionProtoConverter NESTED_TYPES_EXPRESSION_TO_PROTO = + new ExpressionProtoConverter( + NESTED_TYPES_FUNCTION_COLLECTOR, NESTED_TYPES_REL_PROTO_CONVERTER); + private static final ProtoExpressionConverter NESTED_TYPES_PROTO_TO_EXPRESSION = + new ProtoExpressionConverter( + NESTED_TYPES_FUNCTION_COLLECTOR, + NESTED_TYPES_EXTENSIONS, + EMPTY_TYPE, + NESTED_TYPES_PROTO_REL_CONVERTER); + + private void verifyNestedTypesRoundTrip(Expression expression) { + io.substrait.proto.Expression protoExpression = + NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(expression); + Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); + assertEquals(expression, result); + } + @Test void decimal() { io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = - new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); + verifyRoundTrip(val); + } + + /** Verifies round-trip conversion of a simple user-defined type 
using Any representation. */ + @Test + void userDefinedLiteralWithAnyRepresentation() { + Any anyValue = + Any.pack(com.google.protobuf.StringValue.of("")); + + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue); + + verifyRoundTrip(val); + } + + /** Verifies round-trip conversion of a simple user-defined type using Struct representation. */ + @Test + void userDefinedLiteralWithStructRepresentation() { + java.util.List fields = + java.util.Arrays.asList( + ExpressionCreator.i32(false, 42), ExpressionCreator.i32(false, 100)); + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralStruct( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + fields); + + verifyRoundTrip(val); + } + + /** + * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three + * point UDTs. Both outer and nested types use Struct representation. 
+ */ + @Test + void nestedUserDefinedLiteralWithStructRepresentation() { + Expression.UserDefinedStructLiteral p1 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 0), ExpressionCreator.i32(false, 0))); + + Expression.UserDefinedStructLiteral p2 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 0))); + + Expression.UserDefinedStructLiteral p3 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 10))); + + Expression.UserDefinedStructLiteral triangle = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "triangle", + Collections.emptyList(), + java.util.Arrays.asList(p1, p2, p3)); + + verifyNestedTypesRoundTrip(triangle); + } + + /** + * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three + * point UDTs. Both outer and nested types use Any representation. + */ + @Test + void nestedUserDefinedLiteralWithAnyRepresentation() { + Any triangleAny = + Any.pack(com.google.protobuf.StringValue.of("")); + + Expression.UserDefinedAnyLiteral triangle = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "triangle", Collections.emptyList(), triangleAny); + + verifyNestedTypesRoundTrip(triangle); + } + + /** + * Verifies round-trip conversion of nested user-defined types with mixed representations. The + * triangle UDT uses Struct representation while the nested point UDTs use Any representation. 
+ */ + @Test + void mixedRepresentationNestedUserDefinedLiteral() { + Any anyValue = + Any.pack(com.google.protobuf.StringValue.of("")); + + // Create point UDTs using Any representation + Expression.UserDefinedAnyLiteral p1 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); + + Expression.UserDefinedAnyLiteral p2 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); + + Expression.UserDefinedAnyLiteral p3 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); + + // Create a "triangle" UDT using Struct representation, but with Any-encoded point fields + Expression.UserDefinedStructLiteral triangle = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "triangle", + Collections.emptyList(), + java.util.Arrays.asList(p1, p2, p3)); + + verifyNestedTypesRoundTrip(triangle); + } + + /** + * Verifies round-trip conversion of a parameterized user-defined type. Tests that type parameters + * are correctly preserved during serialization and deserialization. + */ + @Test + void userDefinedLiteralWithTypeParameters() { + // Create a type parameter for i32 + io.substrait.type.Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + // Create a vector instance with fields (x: 1, y: 2, z: 3) + Expression.UserDefinedStructLiteral vectorI32 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "vector", + java.util.Arrays.asList(typeParam), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 1), + ExpressionCreator.i32(false, 2), + ExpressionCreator.i32(false, 3))); + + verifyNestedTypesRoundTrip(vectorI32); + } + + /** + * Verifies round-trip conversion of a user-defined type with all parameter types. 
Tests that all + * parameter kinds (type, integer, boolean, string, null, enum) are correctly preserved during + * serialization and deserialization. + */ + @Test + void userDefinedLiteralWithAllParameterTypes() { + io.substrait.type.Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + io.substrait.type.Type.Parameter intParam = + io.substrait.type.ImmutableType.ParameterIntegerValue.builder().value(100L).build(); + + io.substrait.type.Type.Parameter boolParam = + io.substrait.type.ImmutableType.ParameterBooleanValue.builder().value(true).build(); + + io.substrait.type.Type.Parameter stringParam = + io.substrait.type.ImmutableType.ParameterStringValue.builder().value("utf8").build(); + + io.substrait.type.Type.Parameter nullParam = io.substrait.type.Type.ParameterNull.INSTANCE; + + io.substrait.type.Type.Parameter enumParam = + io.substrait.type.ImmutableType.ParameterEnumValue.builder().value("FAST").build(); + + Expression.UserDefinedStructLiteral multiParam = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "multi_param", + java.util.Arrays.asList( + typeParam, intParam, boolParam, stringParam, nullParam, enumParam), + java.util.Arrays.asList(ExpressionCreator.i32(false, 42))); + + verifyNestedTypesRoundTrip(multiParam); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 71de9a7d5..fe1e7a965 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -37,7 +37,8 @@ import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.Expression.TimestampTZLiteral; import 
io.substrait.expression.Expression.UUIDLiteral; -import io.substrait.expression.Expression.UserDefinedLiteral; +import io.substrait.expression.Expression.UserDefinedAnyLiteral; +import io.substrait.expression.Expression.UserDefinedStructLiteral; import io.substrait.expression.Expression.VarCharLiteral; import io.substrait.expression.Expression.WindowFunctionInvocation; import io.substrait.expression.ExpressionVisitor; @@ -188,9 +189,15 @@ public String visit(StructLiteral expr, EmptyVisitationContext context) throws R } @Override - public String visit(UserDefinedLiteral expr, EmptyVisitationContext context) + public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; + } + + @Override + public String visit(UserDefinedStructLiteral expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; } @Override diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index f3b34f6c2..de67a3ee8 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -91,7 +91,7 @@ public static List explain(io.substrait.plan.Plan plan) { /** * Explains the Sustrait relation * - * @param plan Subsrait relation + * @param rel Subsrait relation * @return List of strings; typically these would then be logged or sent to stdout */ public static List explain(io.substrait.relation.Rel rel) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index 8dcfbf9e0..1485bd2f5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -2,6 +2,9 @@ 
import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelNodeConverter.Context; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.plan.Plan; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -94,7 +97,27 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) { *

Override this method to customize the {@link SubstraitRelNodeConverter}. */ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder); + ScalarFunctionConverter scalarFunctionConverter = + new ScalarFunctionConverter( + extensions.scalarFunctions(), + java.util.Collections.emptyList(), + typeFactory, + typeConverter); + AggregateFunctionConverter aggregateFunctionConverter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), + java.util.Collections.emptyList(), + typeFactory, + typeConverter); + WindowFunctionConverter windowFunctionConverter = + new WindowFunctionConverter(extensions.windowFunctions(), typeFactory); + return new SubstraitRelNodeConverter( + typeFactory, + relBuilder, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3406de7de..bd044461b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -44,15 +44,18 @@ public class CallConverters { * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link * Expression.UserDefinedLiteral}s within Calcite. * - *

When converting from Substrait to Calcite, the {@link Expression.UserDefinedLiteral#value()} - * is stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link - * org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type. + *

When converting from Substrait to Calcite, the user-defined literal value is stored either + * as a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link + * org.apache.calcite.rex.RexLiteral} (for ANY-encoded values) or a {@link SqlKind#ROW} (for + * struct-encoded values) and then re-interpreted to have the correct user-defined type. * - *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral, + *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAnyLiteral, + * SubstraitRelNodeConverter.Context)} and {@link + * ExpressionRexConverter#visit(Expression.UserDefinedStructLiteral, * SubstraitRelNodeConverter.Context)} for this conversion. * - *

When converting from Calcite to Substrait, this call converter extracts the {@link - * Expression.UserDefinedLiteral} that was stored. + *

When converting from Calcite to Substrait, this call converter extracts the stored {@link + * Expression.UserDefinedLiteral}. */ public static Function REINTERPRET = typeConverter -> @@ -70,15 +73,58 @@ public class CallConverters { Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; Type.UserDefined t = (Type.UserDefined) type; - return Expression.UserDefinedLiteral.builder() + // The binary literal contains the serialized protobuf Any - just parse it directly + try { + com.google.protobuf.Any anyValue = + com.google.protobuf.Any.parseFrom(literal.value().toByteArray()); + + return Expression.UserDefinedAnyLiteral.builder() + .nullable(t.nullable()) + .urn(t.urn()) + .name(t.name()) + .addAllTypeParameters(t.typeParameters()) + .value(anyValue) + .build(); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw new IllegalStateException("Failed to parse UserDefinedAnyLiteral value", e); + } + } else if (operand instanceof Expression.StructLiteral + && type instanceof Type.UserDefined) { + Expression.StructLiteral structLiteral = (Expression.StructLiteral) operand; + Type.UserDefined t = (Type.UserDefined) type; + + return Expression.UserDefinedStructLiteral.builder() + .nullable(t.nullable()) .urn(t.urn()) .name(t.name()) - .value(literal.value()) + .addAllTypeParameters(t.typeParameters()) + .addAllFields(structLiteral.fields()) .build(); } return null; }; + /** Converts Calcite ROW constructors into Substrait struct literals. 
*/ + public static SimpleCallConverter ROW = + (call, visitor) -> { + if (call.getKind() != SqlKind.ROW) { + return null; + } + + List operands = + call.getOperands().stream().map(visitor).collect(java.util.stream.Collectors.toList()); + if (!operands.stream().allMatch(expr -> expr instanceof Expression.Literal)) { + throw new IllegalArgumentException("ROW operands must be literals."); + } + + List literals = + operands.stream() + .map(expr -> (Expression.Literal) expr) + .collect(java.util.stream.Collectors.toList()); + + return ExpressionCreator.struct(call.getType().isNullable(), literals); + }; + // public static SimpleCallConverter OrAnd(FunctionConverter c) { // return (call, visitor) -> { // if (call.getKind() != SqlKind.AND && call.getKind() != SqlKind.OR) { @@ -139,6 +185,7 @@ public static List defaults(TypeConverter typeConverter) { return ImmutableList.of( new FieldSelectionConverter(typeConverter), CallConverters.CASE, + CallConverters.ROW, CallConverters.CAST.apply(typeConverter), CallConverters.REINTERPRET.apply(typeConverter), new LiteralConstructorConverter(typeConverter)); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2b8052889..162db8302 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -109,7 +109,7 @@ public RexNode visit(Expression.NullLiteral expr, Context context) throws Runtim } @Override - public RexNode visit(Expression.UserDefinedLiteral expr, Context context) + public RexNode visit(Expression.UserDefinedAnyLiteral expr, Context context) throws RuntimeException { RexLiteral binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); @@ -117,6 +117,19 @@ public RexNode visit(Expression.UserDefinedLiteral expr, Context context) return 
rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); } + @Override + public RexNode visit(Expression.UserDefinedStructLiteral expr, Context context) + throws RuntimeException { + RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); + RexNode structValue = toStruct(expr.fields(), expr.nullable(), context); + return rexBuilder.makeReinterpretCast(type, structValue, rexBuilder.makeLiteral(false)); + } + + @Override + public RexNode visit(Expression.StructLiteral expr, Context context) throws RuntimeException { + return toStruct(expr.fields(), expr.nullable(), context); + } + @Override public RexNode visit(Expression.BoolLiteral expr, Context context) throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); @@ -249,6 +262,26 @@ public RexNode visit(PrecisionTimestampTZLiteral expr, Context context) throws R typeConverter.toCalcite(typeFactory, expr.getType())); } + private RexNode toStruct( + List fields, boolean nullable, Context context) { + List fieldNodes = + fields.stream().map(f -> f.accept(this, context)).collect(Collectors.toList()); + + RelDataTypeFactory.Builder rowBuilder = typeFactory.builder(); + IntStream.range(0, fields.size()) + .forEach( + i -> + rowBuilder.add( + "field" + i, typeConverter.toCalcite(typeFactory, fields.get(i).getType()))); + + RelDataType rowType = rowBuilder.build(); + if (nullable) { + rowType = typeFactory.createTypeWithNullability(rowType, true); + } + + return rexBuilder.makeCall(rowType, SqlStdOperatorTable.ROW, fieldNodes); + } + private TimestampString getTimestampString(long microSec) { return getTimestampString(microSec, 6); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index c80f6a4ba..66eff0769 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -7,7 
+7,6 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; -import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.FunctionMappings; @@ -16,9 +15,7 @@ import io.substrait.isthmus.utils.UserTypeFactory; import io.substrait.proto.Expression; import io.substrait.proto.Expression.Literal.Builder; -import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; -import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; @@ -587,7 +584,9 @@ void customTypesInFunctionsRoundtrip() { void customTypesLiteralInFunctionsRoundtrip() { Builder bldr = Expression.Literal.newBuilder(); Any anyValue = Any.pack(bldr.setI32(10).build()); - UserDefinedLiteral val = ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); + UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue); Rel rel1 = b.project( @@ -599,10 +598,24 @@ void customTypesLiteralInFunctionsRoundtrip() { RelNode calciteRel = substraitToCalcite.convert(rel1); Rel rel2 = calciteToSubstrait.apply(calciteRel); assertEquals(rel1, rel2); + } - ExtensionCollector extensionCollector = new ExtensionCollector(); - io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); - Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); - assertEquals(rel1, rel3); + @Test + void customNullableUserDefinedLiteralRoundtrip() { + Builder bldr = Expression.Literal.newBuilder(); + Any anyValue = Any.pack(bldr.setI32(10).build()); + UserDefinedLiteral nullableLiteral = + ExpressionCreator.userDefinedLiteralAny( + true, URN, "a_type", 
java.util.Collections.emptyList(), anyValue); + + Rel rel = + b.project( + input -> List.of(nullableLiteral), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + Rel relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java new file mode 100644 index 000000000..83e3ba103 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java @@ -0,0 +1,300 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.protobuf.Any; +import com.google.protobuf.StringValue; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.isthmus.utils.UserTypeFactory; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.Test; + +public class UserDefinedLiteralRoundtripTest extends PlanTestBase { + + private static final String NESTED_TYPES_URN = "extension:io.substrait:test_nested_types"; + + private static final String NESTED_TYPES_YAML = + "---\n" + + "urn: " + + NESTED_TYPES_URN + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " latitude: 
i32\n" + + " longitude: i32\n" + + " - name: triangle\n" + + " structure:\n" + + " p1: point\n" + + " p2: point\n" + + " p3: point\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n" + + " - name: multi_param\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " - name: size\n" + + " type: integer\n" + + " - name: nullable\n" + + " type: boolean\n" + + " - name: encoding\n" + + " type: string\n" + + " - name: precision\n" + + " type: dataType\n" + + " - name: mode\n" + + " type: enum\n" + + " structure:\n" + + " value: T\n"; + + private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = + SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); + + private final SubstraitBuilder builder = new SubstraitBuilder(NESTED_TYPES_EXTENSIONS); + + private final Map userTypeFactories = + Map.of( + "point", new UserTypeFactory(NESTED_TYPES_URN, "point"), + "triangle", new UserTypeFactory(NESTED_TYPES_URN, "triangle"), + "vector", new UserTypeFactory(NESTED_TYPES_URN, "vector"), + "multi_param", new UserTypeFactory(NESTED_TYPES_URN, "multi_param")); + + private final UserTypeMapper userTypeMapper = + new UserTypeMapper() { + @Override + public @Nullable Type toSubstrait(RelDataType relDataType) { + return userTypeFactories.values().stream() + .filter(factory -> factory.isTypeFromFactory(relDataType)) + .findFirst() + .map( + factory -> + factory.createSubstrait( + relDataType.isNullable(), factory.getTypeParameters(relDataType))) + .orElse(null); + } + + @Override + public @Nullable RelDataType toCalcite(Type.UserDefined type) { + if (!type.urn().equals(NESTED_TYPES_URN)) { + return null; + } + UserTypeFactory factory = userTypeFactories.get(type.name()); + if (factory == null) { + return null; + } + + return factory.createCalcite(type.nullable(), type.typeParameters()); + } + }; + + private final TypeConverter typeConverter = new 
TypeConverter(userTypeMapper); + + private final ScalarFunctionConverter scalarFunctionConverter = + new ScalarFunctionConverter( + NESTED_TYPES_EXTENSIONS.scalarFunctions(), + Collections.emptyList(), + typeFactory, + typeConverter); + + private final AggregateFunctionConverter aggregateFunctionConverter = + new AggregateFunctionConverter( + NESTED_TYPES_EXTENSIONS.aggregateFunctions(), + Collections.emptyList(), + typeFactory, + typeConverter); + + private final WindowFunctionConverter windowFunctionConverter = + new WindowFunctionConverter(NESTED_TYPES_EXTENSIONS.windowFunctions(), typeFactory); + + private final SubstraitToCalcite substraitToCalcite = + new SubstraitToCalcite(NESTED_TYPES_EXTENSIONS, typeFactory, typeConverter); + + private final SubstraitRelVisitor calciteToSubstrait = + new SubstraitRelVisitor( + typeFactory, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter, + ImmutableFeatureBoard.builder().build()); + + private void assertRoundTrip(Expression.UserDefinedLiteral literal) { + Rel rel = + builder.project( + input -> List.of(literal), + builder.remap(1), + builder.namedScan(List.of("example"), List.of("udt_col"), List.of(literal.getType()))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + Rel relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + private Expression.UserDefinedStructLiteral pointStructLiteral(int latitude, int longitude) { + return ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, latitude), ExpressionCreator.i32(false, longitude))); + } + + private Expression.UserDefinedStructLiteral triangleStruct( + Expression.UserDefinedLiteral p1, + Expression.UserDefinedLiteral p2, + Expression.UserDefinedLiteral p3) { + return ExpressionCreator.userDefinedLiteralStruct( + false, NESTED_TYPES_URN, "triangle", 
Collections.emptyList(), Arrays.asList(p1, p2, p3)); + } + + private Expression.UserDefinedAnyLiteral pointAnyLiteral(String value) { + return ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), Any.pack(StringValue.of(value))); + } + + private Expression.UserDefinedStructLiteral vectorStructLiteral( + List params, + Expression.Literal x, + Expression.Literal y, + Expression.Literal z) { + return ExpressionCreator.userDefinedLiteralStruct( + false, NESTED_TYPES_URN, "vector", params, Arrays.asList(x, y, z)); + } + + @Test + void anyEncodedUdtRoundTrip() { + Expression.UserDefinedLiteral literal = + ExpressionCreator.userDefinedLiteralAny( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Any.pack(StringValue.of(""))); + + assertRoundTrip(literal); + } + + @Test + void structEncodedUdtRoundTrip() { + assertRoundTrip(pointStructLiteral(42, 100)); + } + + @Test + void nestedStructEncodedUdtRoundTrip() { + assertRoundTrip( + triangleStruct( + pointStructLiteral(0, 0), pointStructLiteral(10, 0), pointStructLiteral(5, 10))); + } + + @Test + void nestedMixedEncodingsRoundTrip() { + // Mix encodings: struct, any, struct. 
+ assertRoundTrip( + triangleStruct( + pointStructLiteral(1, 2), pointAnyLiteral("p2-any"), pointStructLiteral(3, 4))); + } + + @Test + void parameterizedUdtRoundTrip() { + Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + Expression.UserDefinedLiteral literal = + vectorStructLiteral( + Collections.singletonList(typeParam), + ExpressionCreator.i32(false, 1), + ExpressionCreator.i32(false, 2), + ExpressionCreator.i32(false, 3)); + + assertRoundTrip(literal); + } + + @Test + void parameterizedUdtAllParamKindsRoundTrip() { + Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + Type.Parameter intParam = + io.substrait.type.ImmutableType.ParameterIntegerValue.builder().value(100L).build(); + + Type.Parameter boolParam = + io.substrait.type.ImmutableType.ParameterBooleanValue.builder().value(true).build(); + + Type.Parameter stringParam = + io.substrait.type.ImmutableType.ParameterStringValue.builder().value("utf8").build(); + + Type.Parameter nullParam = io.substrait.type.Type.ParameterNull.INSTANCE; + + Type.Parameter enumParam = + io.substrait.type.ImmutableType.ParameterEnumValue.builder().value("FAST").build(); + + Expression.UserDefinedLiteral literal = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "multi_param", + Arrays.asList(typeParam, intParam, boolParam, stringParam, nullParam, enumParam), + Arrays.asList(ExpressionCreator.i32(false, 42))); + + assertRoundTrip(literal); + } + + @Test + void multipleParameterizedUdtInstancesRoundTrip() { + Type.Parameter i32Param = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + Type.Parameter fp64Param = + 
io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.FP64.builder().nullable(false).build()) + .build(); + + Expression.UserDefinedLiteral vecI32 = + vectorStructLiteral( + Collections.singletonList(i32Param), + ExpressionCreator.i32(false, 1), + ExpressionCreator.i32(false, 2), + ExpressionCreator.i32(false, 3)); + + Expression.UserDefinedLiteral vecFp64 = + vectorStructLiteral( + Collections.singletonList(fp64Param), + ExpressionCreator.fp64(false, 1.1), + ExpressionCreator.fp64(false, 2.2), + ExpressionCreator.fp64(false, 3.3)); + + Rel rel = + builder.project( + input -> Arrays.asList(vecI32, vecFp64), builder.remap(0, 1), builder.emptyScan()); + + RelNode calciteRel = substraitToCalcite.convert(rel); + Rel relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 2c90f133d..2a70b87f7 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -1,7 +1,9 @@ package io.substrait.isthmus.utils; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeImpl; import org.apache.calcite.sql.type.SqlTypeName; @@ -19,34 +21,65 @@ public class UserTypeFactory { public UserTypeFactory(String urn, String name) { this.urn = urn; this.name = name; - this.N = new InnerType(true, name); - this.R = new InnerType(false, name); + this.N = new InnerType(urn, name, true, Collections.emptyList()); + this.R = new InnerType(urn, name, false, Collections.emptyList()); } public RelDataType createCalcite(boolean nullable) { - if (nullable) { - return N; - } else { 
- return R; + return createCalcite(nullable, Collections.emptyList()); + } + + public RelDataType createCalcite(boolean nullable, List typeParameters) { + if (typeParameters.isEmpty()) { + return nullable ? N : R; } + + return new InnerType(urn, name, nullable, typeParameters); } public Type createSubstrait(boolean nullable) { - return TypeCreator.of(nullable).userDefined(urn, name); + return createSubstrait(nullable, Collections.emptyList()); + } + + public Type createSubstrait(boolean nullable, List typeParameters) { + return Type.UserDefined.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(typeParameters) + .build(); } public boolean isTypeFromFactory(RelDataType type) { - return type == N || type == R; + // We may return cached instances (N/R) or fresh InnerType instances with parameters. + // Use instanceof to recognize any of them and match by urn/name so custom UDT mappings work. + if (type instanceof InnerType) { + InnerType inner = (InnerType) type; + return urn.equals(inner.urn) && name.equals(inner.name); + } + return false; + } + + public List getTypeParameters(RelDataType type) { + if (type instanceof InnerType) { + return ((InnerType) type).typeParameters; + } + return Collections.emptyList(); } private static class InnerType extends RelDataTypeImpl { private final boolean nullable; + private final String urn; private final String name; + private final List typeParameters; - private InnerType(boolean nullable, String name) { - computeDigest(); - this.nullable = nullable; + private InnerType( + String urn, String name, boolean nullable, List typeParameters) { + this.urn = urn; this.name = name; + this.nullable = nullable; + this.typeParameters = Collections.unmodifiableList(typeParameters); + computeDigest(); } @Override @@ -61,7 +94,13 @@ public SqlTypeName getSqlTypeName() { @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - sb.append(name); + sb.append(urn).append(":").append(name); 
+ + if (!typeParameters.isEmpty()) { + sb.append("<"); + sb.append(typeParameters.stream().map(Object::toString).collect(Collectors.joining(","))); + sb.append(">"); + } } } } diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 5377f4257..10c134658 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -77,7 +77,13 @@ class ExpressionToString extends DefaultExpressionVisitor[String] { } override def visit( - expr: Expression.UserDefinedLiteral, + expr: Expression.UserDefinedAnyLiteral, + context: EmptyVisitationContext): String = { + expr.toString + } + + override def visit( + expr: Expression.UserDefinedStructLiteral, context: EmptyVisitationContext): String = { expr.toString } diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala index 5f7137b14..d1b7a32a6 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -65,9 +65,11 @@ class DefaultExpressionVisitor[T] context: EmptyVisitationContext): T = e.accept(this, context) + override def visit(expr: Expression.UserDefinedAnyLiteral, context: EmptyVisitationContext): T = + visitFallback(expr, context) + override def visit( - userDefinedLiteral: Expression.UserDefinedLiteral, - context: EmptyVisitationContext): T = { - visitFallback(userDefinedLiteral, context) - } + expr: Expression.UserDefinedStructLiteral, + context: EmptyVisitationContext): T = + visitFallback(expr, context) }