diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 072507295..d190542f8 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E { return visitFallback(expr, context); } + @Override + public O visit(Expression.UserDefinedAny expr, C context) throws E { + return visitFallback(expr, context); + } + + @Override + public O visit(Expression.UserDefinedStruct expr, C context) throws E { + return visitFallback(expr, context); + } + @Override public O visit(Expression.Switch expr, C context) throws E { return visitFallback(expr, context); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..1b0a8362d 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -662,21 +662,96 @@ public R accept( } } + /** + * Base interface for user-defined literals. + * + *

User-defined literals can be encoded in one of two ways as per the Substrait spec: + * + *

+ * + * @see UserDefinedAny + * @see UserDefinedStruct + */ + interface UserDefinedLiteral extends Literal { + String urn(); + + String name(); + + List typeParameters(); + } + + /** + * User-defined literal with value encoded as {@code google.protobuf.Any}. + * + *

This encoding allows for arbitrary binary data to be stored in the literal value. + */ @Value.Immutable - abstract class UserDefinedLiteral implements Literal { - public abstract ByteString value(); + abstract class UserDefinedAny implements UserDefinedLiteral { + @Override + public abstract String urn(); + + @Override + public abstract String name(); + + @Override + public abstract List typeParameters(); + + public abstract com.google.protobuf.Any value(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); + } + + public static ImmutableExpression.UserDefinedAny.Builder builder() { + return ImmutableExpression.UserDefinedAny.builder(); + } + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + + /** + * User-defined literal with value encoded as {@code Literal.Struct}. + * + *

This encoding uses a structured list of fields to represent the literal value. + */ + @Value.Immutable + abstract class UserDefinedStruct implements UserDefinedLiteral { + @Override public abstract String urn(); + @Override public abstract String name(); @Override - public Type getType() { - return Type.withNullability(nullable()).userDefined(urn(), name()); + public abstract List typeParameters(); + + public abstract List fields(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); } - public static ImmutableExpression.UserDefinedLiteral.Builder builder() { - return ImmutableExpression.UserDefinedLiteral.builder(); + public static ImmutableExpression.UserDefinedStruct.Builder builder() { + return ImmutableExpression.UserDefinedStruct.builder(); } @Override diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index adf157d7b..2f924bef8 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -286,13 +286,51 @@ public static Expression.StructLiteral struct( return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build(); } - public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String urn, String name, Any value) { - return Expression.UserDefinedLiteral.builder() + /** + * Create a UserDefinedAny with google.protobuf.Any representation. 
+ * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param value the value, encoded as google.protobuf.Any + */ + public static Expression.UserDefinedAny userDefinedLiteralAny( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + Any value) { + return Expression.UserDefinedAny.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(typeParameters) + .value(value) + .build(); + } + + /** + * Create a UserDefinedStruct with Struct representation. + * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param fields the fields, as a list of Literal values + */ + public static Expression.UserDefinedStruct userDefinedLiteralStruct( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + java.util.List fields) { + return Expression.UserDefinedStruct.builder() .nullable(nullable) .urn(urn) .name(name) - .value(value.toByteString()) + .addAllTypeParameters(typeParameters) + .addAllFields(fields) .build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index d64cab48c..7cec9b953 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -62,7 +62,9 @@ public interface ExpressionVisitor { - try { - bldr.setNullable(expr.nullable()) - .setUserDefined( - Expression.Literal.UserDefined.newBuilder() - .setTypeReference(typeReference) - .setValue(Any.parseFrom(expr.value()))) - .build(); - } catch 
(InvalidProtocolBufferException e) { - throw new IllegalStateException(e); + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters(expr.typeParameters()) + .setValue(expr.value()); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); + }); + } + + @Override + public Expression visit( + io.substrait.expression.Expression.UserDefinedStruct expr, EmptyVisitationContext context) { + int typeReference = + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); + return lit( + bldr -> { + Expression.Literal.Struct.Builder structBuilder = Expression.Literal.Struct.newBuilder(); + for (io.substrait.expression.Expression.Literal field : expr.fields()) { + structBuilder.addFields(toLiteral(field)); } + + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters(expr.typeParameters()) + .setStruct(structBuilder.build()); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); }); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8f95cdf07..847fcae55 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -492,10 +492,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = literal.getUserDefined(); + SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); - return ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); + 
String urn = type.urn(); + String name = type.name(); + + switch (userDefinedLiteral.getValCase()) { + case VALUE: + return ExpressionCreator.userDefinedLiteralAny( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getValue()); + case STRUCT: + return ExpressionCreator.userDefinedLiteralStruct( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getStruct().getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())); + case VAL_NOT_SET: + throw new IllegalStateException( + "UserDefined literal has no value (neither 'value' nor 'struct' is set)"); + default: + throw new IllegalStateException( + "Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase()); + } } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 89aad954e..31214878c 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog { "extension:io.substrait:functions_rounding_decimal"; public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; + public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types"; public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = loadDefaultCollection(); @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { .map(c -> String.format("/functions_%s.yaml", c)) .collect(Collectors.toList()); + defaultFiles.add("/extension_types.yaml"); + return SimpleExtension.load(defaultFiles); } } 
diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 57132a940..68395ac0d 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -203,9 +203,15 @@ public Optional visit(Expression.StructLiteral expr, EmptyVisitation return visitLiteral(expr); } + @Override + public Optional visit(Expression.UserDefinedAny expr, EmptyVisitationContext context) + throws E { + return visitLiteral(expr); + } + @Override public Optional visit( - Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E { + Expression.UserDefinedStruct expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..7ef2d75a7 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -393,6 +393,23 @@ abstract class UserDefined implements Type { public abstract String name(); + /** + * Returns the type parameters for this user-defined type. + * + *

Type parameters are used to represent parameterized/generic types, such as {@code + * List} or {@code Map}. Each parameter in the list represents a type argument + * that specializes the generic user-defined type. + * + *

For example, a user-defined type {@code MyList} parameterized by {@code i32} would have + * one type parameter containing the {@code i32} type definition. + * + * @return a list of type parameters, or an empty list if this type is not parameterized + */ + @Value.Default + public java.util.List typeParameters() { + return java.util.Collections.emptyList(); + } + public static ImmutableType.UserDefined.Builder builder() { return ImmutableType.UserDefined.builder(); } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 691d4bce5..67d7bc9b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -165,6 +165,6 @@ public final T visit(final Type.Map expr) { public final T visit(final Type.UserDefined expr) { int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); - return typeContainer(expr).userDefined(ref); + return typeContainer(expr).userDefined(ref, expr.typeParameters()); } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 6a1bc3186..1009fe52a 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -131,6 +131,9 @@ public final T struct(T... 
types) { public abstract T userDefined(int ref); + public abstract T userDefined( + int ref, java.util.List typeParameters); + protected abstract T wrap(Object o); protected abstract I i(int integerValue); diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 4e0caa7c2..137c1fba3 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -262,6 +262,13 @@ public ParameterizedType userDefined(int ref) { "User defined types are not supported in Parameterized Types for now"); } + @Override + public ParameterizedType userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Parameterized Types for now"); + } + @Override protected ParameterizedType wrap(final Object o) { ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 95d42328a..ee77e1445 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -90,7 +90,13 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); + boolean nullable = isNullable(userDefined.getNullability()); + return io.substrait.type.Type.UserDefined.builder() + .nullable(nullable) + .urn(t.urn()) + .name(t.name()) + .typeParameters(userDefined.getTypeParametersList()) + .build(); } case USER_DEFINED_TYPE_REFERENCE: throw new 
UnsupportedOperationException("Unsupported user defined reference: " + type); diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 96cddd395..a3412a9e3 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -355,6 +355,13 @@ public DerivationExpression userDefined(int ref) { "User defined types are not supported in Derivation Expressions for now"); } + @Override + public DerivationExpression userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Derivation Expressions for now"); + } + @Override protected DerivationExpression wrap(final Object o) { DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 2d0ed0ffc..7cb98263f 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -133,6 +133,17 @@ public Type userDefined(int ref) { Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build()); } + @Override + public Type userDefined( + int ref, java.util.List typeParameters) { + return wrap( + Type.UserDefined.newBuilder() + .setTypeReference(ref) + .setNullability(nullability) + .addAllTypeParameters(typeParameters) + .build()); + } + @Override protected Type wrap(final Object o) { Type.Builder bldr = Type.newBuilder(); diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 3defbf78f..b5f1dd4f1 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ 
b/core/src/test/java/io/substrait/TestBase.java @@ -1,8 +1,12 @@ package io.substrait; +import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; @@ -25,9 +29,22 @@ public abstract class TestBase { protected ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, defaultExtensionCollection); + protected ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + + protected ProtoExpressionConverter protoExpressionConverter = + new ProtoExpressionConverter( + functionCollector, defaultExtensionCollection, EMPTY_TYPE, protoRelConverter); + protected void verifyRoundTrip(Rel rel) { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } + + protected void verifyRoundTrip(Expression expression) { + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(expression); + Expression expressionReturned = protoExpressionConverter.from(protoExpression); + assertEquals(expression, expressionReturned); + } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index ccac93bcb..8e4b13cbe 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -1,13 +1,10 @@ package io.substrait.type.proto; -import static 
io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; -import static org.junit.jupiter.api.Assertions.assertEquals; - +import com.google.protobuf.Any; import io.substrait.TestBase; +import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.proto.ExpressionProtoConverter; -import io.substrait.expression.proto.ProtoExpressionConverter; -import io.substrait.util.EmptyVisitationContext; +import io.substrait.extension.DefaultExtensionCatalog; import java.math.BigDecimal; import org.junit.jupiter.api.Test; @@ -17,9 +14,45 @@ public class LiteralRoundtripTest extends TestBase { void decimal() { io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = - new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); + verifyRoundTrip(val); + } + + @Test + void userDefinedLiteralWithAnyRepresentation() { + // Create a struct literal inline representing a point with latitude=42, longitude=100 + io.substrait.proto.Expression.Literal.Struct pointStruct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(42)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(100)) + .build(); + io.substrait.proto.Expression.Literal innerLiteral = + io.substrait.proto.Expression.Literal.newBuilder().setStruct(pointStruct).build(); + Any anyValue = Any.pack(innerLiteral); + + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue); + + verifyRoundTrip(val); + } + + @Test + void userDefinedLiteralWithStructRepresentation() { + 
java.util.List fields = + java.util.Arrays.asList( + ExpressionCreator.i32(false, 42), ExpressionCreator.i32(false, 100)); + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralStruct( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + fields); + + verifyRoundTrip(val); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 71de9a7d5..cee84d13d 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -37,7 +37,8 @@ import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.Expression.TimestampTZLiteral; import io.substrait.expression.Expression.UUIDLiteral; -import io.substrait.expression.Expression.UserDefinedLiteral; +import io.substrait.expression.Expression.UserDefinedAny; +import io.substrait.expression.Expression.UserDefinedStruct; import io.substrait.expression.Expression.VarCharLiteral; import io.substrait.expression.Expression.WindowFunctionInvocation; import io.substrait.expression.ExpressionVisitor; @@ -188,9 +189,14 @@ public String visit(StructLiteral expr, EmptyVisitationContext context) throws R } @Override - public String visit(UserDefinedLiteral expr, EmptyVisitationContext context) + public String visit(UserDefinedAny expr, EmptyVisitationContext context) throws RuntimeException { + return ""; + } + + @Override + public String visit(UserDefinedStruct expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; } @Override diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index 
f3b34f6c2..de67a3ee8 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -91,7 +91,7 @@ public static List explain(io.substrait.plan.Plan plan) { /** * Explains the Sustrait relation * - * @param plan Subsrait relation + * @param rel Substrait relation * @return List of strings; typically these would then be logged or sent to stdout */ public static List explain(io.substrait.relation.Rel rel) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 932b8f6d8..c332dfd19 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -363,8 +363,7 @@ public RelDataType visit(Type.UserDefined expr) throws RuntimeException { if (type != null) { return type; } - throw new UnsupportedOperationException( - String.format("Unable to map user-defined type: %s", expr)); + return io.substrait.isthmus.type.SubstraitUserDefinedType.from(expr); } private boolean n(NullableType type) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3406de7de..f45ca86d3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -5,7 +5,7 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.isthmus.CallConverter; import io.substrait.isthmus.TypeConverter; -import io.substrait.type.Type; +import io.substrait.isthmus.type.SubstraitUserDefinedType; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -41,18 +41,26 @@ public class CallConverters { }; /** - * {@link SqlKind#REINTERPRET} is utilized by Isthmus to 
represent and store {@link - * Expression.UserDefinedLiteral}s within Calcite. + * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent {@link + * Expression.UserDefinedAny} literals within Calcite. * - *

When converting from Substrait to Calcite, the {@link Expression.UserDefinedLiteral#value()} - * is stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link - * org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type. + *

When converting from Substrait to Calcite, UserDefinedAny literals are serialized to binary + * and stored as {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link + * org.apache.calcite.rex.RexLiteral}, then re-interpreted to have a custom {@link + * SubstraitUserDefinedType.SubstraitUserDefinedAnyType} that preserves all metadata including + * type parameters. * - *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral, - * SubstraitRelNodeConverter.Context)} for this conversion. + *

Note: {@link Expression.UserDefinedStruct} literals are NOT handled via REINTERPRET. + * Instead, they are represented as Calcite ROW literals with {@link + * SubstraitUserDefinedType.SubstraitUserDefinedStructType} and converted via {@link + * LiteralConverter}. * - *

When converting from Calcite to Substrait, this call converter extracts the {@link - * Expression.UserDefinedLiteral} that was stored. + *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAny, + * SubstraitRelNodeConverter.Context)} for the UserDefinedAny conversion. + * + *

When converting from Calcite back to Substrait, this call converter deserializes the binary + * value and reconstructs the UserDefinedAny literal with all metadata preserved (including type + * parameters). */ public static Function REINTERPRET = typeConverter -> @@ -61,20 +69,28 @@ public class CallConverters { return null; } Expression operand = visitor.apply(call.getOperands().get(0)); - Type type = typeConverter.toSubstrait(call.getType()); - // For now, we only support handling of SqlKind.REINTEPRETET for the case of stored - // user-defined literals if (operand instanceof Expression.FixedBinaryLiteral - && type instanceof Type.UserDefined) { + && call.getType() instanceof SubstraitUserDefinedType.SubstraitUserDefinedAnyType) { Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; - Type.UserDefined t = (Type.UserDefined) type; - - return Expression.UserDefinedLiteral.builder() - .urn(t.urn()) - .name(t.name()) - .value(literal.value()) - .build(); + SubstraitUserDefinedType.SubstraitUserDefinedAnyType customType = + (SubstraitUserDefinedType.SubstraitUserDefinedAnyType) call.getType(); + + try { + com.google.protobuf.Any anyValue = + com.google.protobuf.Any.parseFrom(literal.value().toByteArray()); + + return Expression.UserDefinedAny.builder() + .urn(customType.getUrn()) + .name(customType.getName()) + .typeParameters(customType.getTypeParameters()) + .value(anyValue) + .nullable(customType.isNullable()) + .build(); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Failed to parse UserDefinedAny literal value", e); + } } return null; }; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2b8052889..9d7ea630b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ 
b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -109,12 +109,110 @@ public RexNode visit(Expression.NullLiteral expr, Context context) throws Runtim } @Override - public RexNode visit(Expression.UserDefinedLiteral expr, Context context) - throws RuntimeException { + public RexNode visit(Expression.UserDefinedAny expr, Context context) throws RuntimeException { + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType customType = + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType.from( + expr.getType()); + RexLiteral binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); - RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); - return rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); + return rexBuilder.makeReinterpretCast(customType, binaryLiteral, rexBuilder.makeLiteral(false)); + } + + @Override + public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws RuntimeException { + return toUserDefinedStructLiteral(expr, context); + } + + private RexLiteral toUserDefinedStructLiteral(Expression.UserDefinedStruct expr, Context context) { + java.util.List fieldTypes = + new java.util.ArrayList<>(expr.fields().size()); + java.util.List fieldLiterals = new java.util.ArrayList<>(expr.fields().size()); + + for (Expression.Literal field : expr.fields()) { + fieldTypes.add(toStructFieldType(field)); + fieldLiterals.add(literalToRexLiteral(field, context)); + } + + java.util.List fieldNames = + java.util.stream.IntStream.range(0, expr.fields().size()) + .mapToObj(i -> "f" + i) + .collect(java.util.stream.Collectors.toList()); + + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType customType = + new io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType( + expr.urn(), + expr.name(), + expr.typeParameters(), + 
expr.nullable(), + fieldTypes, + fieldNames); + + return (RexLiteral) rexBuilder.makeLiteral(fieldLiterals, customType, false); + } + + private org.apache.calcite.rel.type.RelDataType toStructFieldType(Expression.Literal field) { + if (field instanceof Expression.UserDefinedAny) { + io.substrait.type.Type.UserDefined userDefinedType = + (io.substrait.type.Type.UserDefined) field.getType(); + return io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType.from( + userDefinedType); + } + return typeConverter.toCalcite(typeFactory, field.getType()); + } + + private RexLiteral toStructFieldLiteral(Expression.Literal field, Context context) { + return literalToRexLiteral(field, context); + } + + private RexLiteral literalToRexLiteral(Expression.Literal literal, Context context) { + if (literal instanceof Expression.UserDefinedAny) { + Expression.UserDefinedAny userDefinedAny = (Expression.UserDefinedAny) literal; + org.apache.calcite.avatica.util.ByteString bytes = + new org.apache.calcite.avatica.util.ByteString(userDefinedAny.value().toByteArray()); + return rexBuilder.makeBinaryLiteral(bytes); + } + + if (literal instanceof Expression.UserDefinedStruct) { + return toUserDefinedStructLiteral((Expression.UserDefinedStruct) literal, context); + } + + if (literal instanceof Expression.StructLiteral) { + java.util.List fieldValues = + new java.util.ArrayList<>(((Expression.StructLiteral) literal).fields().size()); + for (Expression.Literal child : ((Expression.StructLiteral) literal).fields()) { + fieldValues.add(literalToRexLiteral(child, context)); + } + return (RexLiteral) + rexBuilder.makeLiteral( + fieldValues, typeConverter.toCalcite(typeFactory, literal.getType()), false); + } + + if (literal instanceof Expression.ListLiteral) { + java.util.List elements = + new java.util.ArrayList<>(((Expression.ListLiteral) literal).values().size()); + for (Expression.Literal child : ((Expression.ListLiteral) literal).values()) { + 
elements.add(literalToRexLiteral(child, context)); + } + return (RexLiteral) + rexBuilder.makeLiteral( + elements, typeConverter.toCalcite(typeFactory, literal.getType()), false); + } + + if (literal instanceof Expression.EmptyListLiteral) { + return (RexLiteral) + rexBuilder.makeLiteral( + java.util.Collections.emptyList(), + typeConverter.toCalcite(typeFactory, literal.getType()), + false); + } + + RexNode rexField = literal.accept(this, context); + if (!(rexField instanceof RexLiteral)) { + throw new IllegalArgumentException( + "Expected literal when converting UserDefinedStruct field but found " + rexField); + } + return (RexLiteral) rexField; } @Override diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 02cb8a116..fd17b8860 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -1,6 +1,8 @@ package io.substrait.isthmus.expression; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.isthmus.TypeConverter; @@ -14,21 +16,22 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.time.temporal.ChronoField; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.util.DateString; import org.apache.calcite.util.NlsString; import 
org.apache.calcite.util.TimeString; import org.apache.calcite.util.TimestampString; public class LiteralConverter { - // TODO: Handle conversion of user-defined type literals - static final DateTimeFormatter CALCITE_LOCAL_DATE_FORMATTER = DateTimeFormatter.ISO_LOCAL_DATE; static final DateTimeFormatter CALCITE_LOCAL_TIME_FORMATTER = new DateTimeFormatterBuilder() @@ -195,9 +198,31 @@ public Expression.Literal convert(RexLiteral literal) { case ROW: { - List literals = (List) literal.getValue(); - return ExpressionCreator.struct( - n, literals.stream().map(this::convert).collect(Collectors.toList())); + @SuppressWarnings("unchecked") + List fieldNodes = (List) literal.getValue(); + List relFields = literal.getType().getFieldList(); + ArrayList convertedFields = new ArrayList<>(fieldNodes.size()); + for (int i = 0; i < fieldNodes.size(); i++) { + convertedFields.add(convertStructField(fieldNodes.get(i), relFields.get(i).getType())); + } + + if (literal.getType() + instanceof + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType) { + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType + udtType = + (io.substrait.isthmus.type.SubstraitUserDefinedType + .SubstraitUserDefinedStructType) + literal.getType(); + return ExpressionCreator.userDefinedLiteralStruct( + udtType.isNullable(), + udtType.getUrn(), + udtType.getName(), + udtType.getTypeParameters(), + convertedFields); + } + + return ExpressionCreator.struct(n, convertedFields); } case ARRAY: @@ -235,4 +260,58 @@ public static byte[] padRightIfNeeded(byte[] value, int length) { System.arraycopy(value, 0, newArray, 0, value.length); return newArray; } + + private Expression.Literal convertStructField(RexNode fieldNode, RelDataType expectedType) { + if (expectedType + instanceof io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType) { + return convertUserDefinedAnyStructField( + fieldNode, + 
(io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType) + expectedType); + } + + if (!(fieldNode instanceof RexLiteral)) { + throw new UnsupportedOperationException( + "Expected literal struct field but found " + fieldNode); + } + return convert((RexLiteral) fieldNode); + } + + private Expression.Literal convertUserDefinedAnyStructField( + RexNode fieldNode, + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType expectedType) { + if (!(fieldNode instanceof RexLiteral)) { + throw new UnsupportedOperationException( + "Expected literal for UserDefinedAny struct field but found " + fieldNode); + } + + RexLiteral literal = (RexLiteral) fieldNode; + if (literal.isNull()) { + return ExpressionCreator.typedNull( + Type.UserDefined.builder() + .urn(expectedType.getUrn()) + .name(expectedType.getName()) + .typeParameters(expectedType.getTypeParameters()) + .nullable(true) + .build()); + } + + org.apache.calcite.avatica.util.ByteString bytes = + literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class); + if (bytes == null) { + throw new IllegalArgumentException( + "Expected binary literal for UserDefinedAny struct field but value was null"); + } + try { + Any anyValue = Any.parseFrom(bytes.getBytes()); + return ExpressionCreator.userDefinedLiteralAny( + expectedType.isNullable(), + expectedType.getUrn(), + expectedType.getName(), + expectedType.getTypeParameters(), + anyValue); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("Failed to parse UserDefinedAny literal", e); + } + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java new file mode 100644 index 000000000..aa155a476 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java @@ -0,0 +1,251 @@ +package io.substrait.isthmus.type; + +import 
com.google.protobuf.TextFormat; +import io.substrait.type.Type; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; +import org.apache.calcite.rel.type.RelDataTypeImpl; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Custom Calcite {@link RelDataType} for Substrait user-defined types. + * + *

This type preserves all UDT metadata (URN, name, type parameters) during Calcite roundtrips. + * It is used when converting types without literal context. For literals, specialized subclasses + * provide representation-specific handling: + * + *

+ * + * @see SubstraitUserDefinedAnyType + * @see SubstraitUserDefinedStructType + * @see io.substrait.expression.Expression.UserDefinedAny + * @see io.substrait.expression.Expression.UserDefinedStruct + */ +public class SubstraitUserDefinedType extends RelDataTypeImpl { + + private final String urn; + private final String name; + private final List typeParameters; + private final boolean nullable; + + public SubstraitUserDefinedType( + String urn, + String name, + List typeParameters, + boolean nullable) { + this.urn = urn; + this.name = name; + this.typeParameters = + typeParameters != null ? typeParameters : java.util.Collections.emptyList(); + this.nullable = nullable; + computeDigest(); + } + + public String getUrn() { + return urn; + } + + public String getName() { + return name; + } + + public List getTypeParameters() { + return typeParameters; + } + + @Override + public boolean isNullable() { + return nullable; + } + + @Override + public SqlTypeName getSqlTypeName() { + return SqlTypeName.OTHER; + } + + /** Converts this Calcite type back to a Substrait {@link Type.UserDefined}. */ + public Type.UserDefined toSubstraitType() { + return Type.UserDefined.builder() + .urn(urn) + .name(name) + .typeParameters(typeParameters) + .nullable(nullable) + .build(); + } + + /** Creates a SubstraitUserDefinedType from a Substrait Type.UserDefined. 
*/ + public static SubstraitUserDefinedType from(io.substrait.type.Type.UserDefined type) { + return new SubstraitUserDefinedType( + type.urn(), type.name(), type.typeParameters(), type.nullable()); + } + + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + appendDigest(sb); + } + + protected void appendDigest(StringBuilder sb) { + sb.append(urn).append("::").append(name); + appendTypeParameters(sb, typeParameters); + } + + private static void appendTypeParameters( + StringBuilder sb, java.util.List parameters) { + if (parameters.isEmpty()) { + return; + } + sb.append("<"); + sb.append( + parameters.stream() + .map(SubstraitUserDefinedType::formatParameter) + .collect(Collectors.joining(","))); + sb.append(">"); + } + + private static String formatParameter(io.substrait.proto.Type.Parameter parameter) { + return TextFormat.shortDebugString(parameter); + } + + /** + * Custom Calcite type representing a Substrait {@link + * io.substrait.expression.Expression.UserDefinedAny} type. + * + *

This type wraps opaque binary data (protobuf Any) and preserves all UDT metadata including + * type parameters during Calcite roundtrips. + * + *

Note: The actual value (protobuf Any) is not stored in the type itself - it's stored in the + * literal. This type only carries the metadata (URN, name, type parameters). + * + *

{@link io.substrait.expression.Expression.UserDefinedAny UserDefinedAny} literals use this + * type when passing through Calcite, as they need to be serialized to binary with REINTERPRET. + * {@link io.substrait.expression.Expression.UserDefinedStruct UserDefinedStruct} literals use + * {@link SubstraitUserDefinedStructType} instead to preserve field structure. + * + * @see SubstraitUserDefinedStructType + * @see io.substrait.expression.Expression.UserDefinedAny + */ + public static class SubstraitUserDefinedAnyType extends SubstraitUserDefinedType { + + public SubstraitUserDefinedAnyType( + String urn, + String name, + List typeParameters, + boolean nullable) { + super(urn, name, typeParameters, nullable); + } + + /** Creates a SubstraitUserDefinedAnyType from a Substrait Type.UserDefined. */ + public static SubstraitUserDefinedAnyType from(io.substrait.type.Type.UserDefined type) { + return new SubstraitUserDefinedAnyType( + type.urn(), type.name(), type.typeParameters(), type.nullable()); + } + } + + /** + * Custom Calcite type representing a Substrait {@link + * io.substrait.expression.Expression.UserDefinedStruct} type. + * + *

This type represents a structured UDT with explicitly defined fields. Unlike {@link + * SubstraitUserDefinedAnyType}, the fields are accessible and can be represented as a Calcite + * STRUCT/ROW type with additional UDT metadata (URN, name, type parameters). + * + *

{@link io.substrait.expression.Expression.UserDefinedStruct UserDefinedStruct} literals use + * this type when passing through Calcite, preserving field structure and enabling field access. + * The fields are converted to Calcite literals and wrapped in a ROW type with synthetic field + * names (f0, f1, f2, etc.). + * + * @see SubstraitUserDefinedAnyType + * @see io.substrait.expression.Expression.UserDefinedStruct + */ + public static class SubstraitUserDefinedStructType extends SubstraitUserDefinedType { + + private final List fieldTypes; + private final List fieldNames; + + public SubstraitUserDefinedStructType( + String urn, + String name, + List typeParameters, + boolean nullable, + List fieldTypes, + List fieldNames) { + super(urn, name, typeParameters, nullable); + if (fieldTypes.size() != fieldNames.size()) { + throw new IllegalArgumentException("Field types and names must have same length"); + } + this.fieldTypes = fieldTypes; + this.fieldNames = fieldNames; + } + + @Override + public List getFieldList() { + java.util.List fields = new java.util.ArrayList<>(); + for (int i = 0; i < fieldTypes.size(); i++) { + fields.add(new RelDataTypeFieldImpl(fieldNames.get(i), i, fieldTypes.get(i))); + } + return fields; + } + + @Override + public int getFieldCount() { + return fieldTypes.size(); + } + + @Override + public RelDataTypeField getField(String fieldName, boolean caseSensitive, boolean elideRecord) { + for (int i = 0; i < fieldNames.size(); i++) { + String name = fieldNames.get(i); + if (caseSensitive ? 
name.equals(fieldName) : name.equalsIgnoreCase(fieldName)) { + return new RelDataTypeFieldImpl(name, i, fieldTypes.get(i)); + } + } + return null; + } + + public List getFieldTypes() { + return fieldTypes; + } + + @Override + public List getFieldNames() { + return fieldNames; + } + + @Override + public SqlTypeName getSqlTypeName() { + // Can be considered as ROW since it has structure + return SqlTypeName.ROW; + } + + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + appendDigest(sb); + if (withDetail && fieldNames != null) { + sb.append("("); + sb.append( + java.util.stream.IntStream.range(0, fieldNames.size()) + .mapToObj(i -> fieldNames.get(i) + ": " + fieldTypes.get(i)) + .collect(java.util.stream.Collectors.joining(", "))); + sb.append(")"); + } + } + + /** + * Creates a SubstraitUserDefinedStructType from a Substrait Type.UserDefined and field + * information. + */ + public static SubstraitUserDefinedStructType from( + io.substrait.type.Type.UserDefined type, + List fieldTypes, + List fieldNames) { + return new SubstraitUserDefinedStructType( + type.urn(), type.name(), type.typeParameters(), type.nullable(), fieldTypes, fieldNames); + } + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index c80f6a4ba..61a0603d0 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -7,7 +7,6 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; -import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.FunctionMappings; @@ -15,10 +14,7 @@ import 
io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.isthmus.utils.UserTypeFactory; import io.substrait.proto.Expression; -import io.substrait.proto.Expression.Literal.Builder; -import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; -import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; @@ -585,24 +581,20 @@ void customTypesInFunctionsRoundtrip() { @Test void customTypesLiteralInFunctionsRoundtrip() { - Builder bldr = Expression.Literal.newBuilder(); + Expression.Literal.Builder bldr = Expression.Literal.newBuilder(); Any anyValue = Any.pack(bldr.setI32(10).build()); - UserDefinedLiteral val = ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); + UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue); - Rel rel1 = + Rel originalRel = b.project( input -> List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - RelNode calciteRel = substraitToCalcite.convert(rel1); - Rel rel2 = calciteToSubstrait.apply(calciteRel); - assertEquals(rel1, rel2); - - ExtensionCollector extensionCollector = new ExtensionCollector(); - io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); - Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); - assertEquals(rel1, rel3); + assertCalciteRoundtrip( + originalRel, substraitToCalcite, calciteToSubstrait, extensionCollection); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index cce58e207..e1e4db3ba 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ 
b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -275,12 +275,50 @@ protected void assertFullRoundTripWithIdentityProjectionWorkaround( } /** - * Verifies that the given POJO can be converted: + * Verifies that a relation roundtrips correctly through Calcite conversion. This is isthmus' core + * responsibility: Substrait ↔ Calcite. * - *

    - *
  • From POJO to Proto and back - *
  • From POJO to Calcite and back - *
+ * @param rel the relation to roundtrip + */ + protected void assertCalciteRoundtrip(Rel rel) { + assertCalciteRoundtrip(rel, null, null, null); + } + + /** + * Verifies that a relation roundtrips correctly through Calcite conversion. This is isthmus' core + * responsibility: Substrait ↔ Calcite. + * + * @param rel the relation to roundtrip + * @param substraitToCalcite custom SubstraitToCalcite converter, or null to use default + * @param substraitRelVisitor custom SubstraitRelVisitor converter, or null to use default + * @param customExtensions custom extension collection, or null to use default + */ + protected void assertCalciteRoundtrip( + Rel rel, + @org.jspecify.annotations.Nullable SubstraitToCalcite substraitToCalcite, + @org.jspecify.annotations.Nullable SubstraitRelVisitor substraitRelVisitor, + SimpleExtension.@org.jspecify.annotations.Nullable ExtensionCollection customExtensions) { + SimpleExtension.ExtensionCollection exts = + customExtensions != null ? customExtensions : extensions; + + // Substrait -> Calcite + SubstraitToCalcite s2c = + substraitToCalcite != null ? substraitToCalcite : new SubstraitToCalcite(exts, typeFactory); + RelNode calcite = s2c.convert(rel); + + // Calcite -> Substrait + io.substrait.relation.Rel roundtripped = + substraitRelVisitor != null + ? substraitRelVisitor.apply(calcite) + : SubstraitRelVisitor.convert(calcite, exts); + + assertEquals(rel, roundtripped); + } + + /** + * Verifies that a relation can be converted through both proto and Calcite roundtrips. + * + * @param pojo1 the relation to roundtrip */ protected void assertFullRoundTrip(Rel pojo1) { // TODO: reuse the Plan.Root based assertFullRoundTrip by generating names @@ -315,6 +353,25 @@ protected void assertFullRoundTrip(Rel pojo1) { * */ protected void assertFullRoundTrip(Plan.Root pojo1) { + assertFullRoundTrip(pojo1, null, null); + } + + /** + * Verifies that the given POJO can be converted: + * + *
    + *
  • From POJO to Proto and back + *
  • From POJO to Calcite and back + *
+ * + * @param pojo1 the plan root to roundtrip + * @param substraitToCalcite custom SubstraitToCalcite converter, or null to use default + * @param substraitRelVisitor custom SubstraitRelVisitor converter, or null to use default + */ + protected void assertFullRoundTrip( + Plan.Root pojo1, + @org.jspecify.annotations.Nullable SubstraitToCalcite substraitToCalcite, + @org.jspecify.annotations.Nullable SubstraitRelVisitor substraitRelVisitor) { ExtensionCollector extensionCollector = new ExtensionCollector(); // Substrait POJO 1 -> Substrait Proto @@ -328,10 +385,17 @@ protected void assertFullRoundTrip(Plan.Root pojo1) { assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelRoot calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + SubstraitToCalcite s2c = + substraitToCalcite != null + ? substraitToCalcite + : new SubstraitToCalcite(extensions, typeFactory); + RelRoot calcite = s2c.convert(pojo2); // Calcite -> Substrait POJO 3 - io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, extensions); + io.substrait.plan.Plan.Root pojo3 = + substraitRelVisitor != null + ? 
SubstraitRelVisitor.convert(calcite, substraitRelVisitor) + : SubstraitRelVisitor.convert(calcite, extensions); // Verify that POJOs are the same assertEquals(pojo1, pojo3); diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java new file mode 100644 index 000000000..adf96ec7a --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java @@ -0,0 +1,752 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import com.google.protobuf.Any; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression.UserDefinedLiteral; +import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.isthmus.utils.UserTypeFactory; +import io.substrait.proto.Expression; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.tools.RelBuilder; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.Test; + +/** + * Tests for User-Defined Type literals, including both UserDefinedAny (protobuf Any-based) and + * UserDefinedStruct (struct-based) encoding strategies. + * + *

These tests verify proto serialization/deserialization and Calcite roundtrips of UDT literals, + * using standard types from extension_types.yaml (point and line). + */ +public class UserDefinedTypeLiteralTest extends PlanTestBase { + + final SubstraitBuilder b = new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION); + + // Create user-defined types using standard types from extension_types.yaml + static final UserTypeFactory pointTypeFactory = + new UserTypeFactory(DefaultExtensionCatalog.EXTENSION_TYPES, "point"); + static final UserTypeFactory lineTypeFactory = + new UserTypeFactory(DefaultExtensionCatalog.EXTENSION_TYPES, "line"); + + // Mapper for user-defined types + static final UserTypeMapper userTypeMapper = + new UserTypeMapper() { + @Nullable + @Override + public Type toSubstrait(RelDataType relDataType) { + if (pointTypeFactory.isTypeFromFactory(relDataType)) { + return TypeCreator.of(relDataType.isNullable()) + .userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point"); + } + if (lineTypeFactory.isTypeFromFactory(relDataType)) { + return TypeCreator.of(relDataType.isNullable()) + .userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line"); + } + return null; + } + + @Nullable + @Override + public RelDataType toCalcite(Type.UserDefined type) { + if (type.urn().equals(DefaultExtensionCatalog.EXTENSION_TYPES)) { + if (type.name().equals("point")) { + return pointTypeFactory.createCalcite(type.nullable()); + } + if (type.name().equals("line")) { + return lineTypeFactory.createCalcite(type.nullable()); + } + } + return null; + } + }; + + TypeConverter typeConverter = new TypeConverter(userTypeMapper); + + // Create Function Converters that can handle the user-defined types + ScalarFunctionConverter scalarFunctionConverter = + new ScalarFunctionConverter( + DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions(), + List.of(), + typeFactory, + typeConverter); + AggregateFunctionConverter aggregateFunctionConverter = + new 
AggregateFunctionConverter( + DefaultExtensionCatalog.DEFAULT_COLLECTION.aggregateFunctions(), + List.of(), + typeFactory, + typeConverter); + WindowFunctionConverter windowFunctionConverter = + new WindowFunctionConverter( + DefaultExtensionCatalog.DEFAULT_COLLECTION.windowFunctions(), typeFactory); + + final SubstraitToCalcite substraitToCalcite = + new CustomSubstraitToCalcite( + DefaultExtensionCatalog.DEFAULT_COLLECTION, typeFactory, typeConverter); + + // Create a SubstraitRelVisitor that uses the custom Function Converters + final SubstraitRelVisitor calciteToSubstrait = + new SubstraitRelVisitor( + typeFactory, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter, + ImmutableFeatureBoard.builder().build()); + + // Create a SubstraitToCalcite converter that has access to the custom Function Converters + class CustomSubstraitToCalcite extends SubstraitToCalcite { + + public CustomSubstraitToCalcite( + SimpleExtension.ExtensionCollection extensions, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter) { + super(extensions, typeFactory, typeConverter); + } + + @Override + protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { + return new SubstraitRelNodeConverter( + typeFactory, + relBuilder, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter); + } + } + + /** + * Verifies proto roundtrip for a relation. This test class needs this method locally since it's + * testing proto serialization (core's responsibility) but must reside in isthmus to access + * Calcite integration components. 
+ */ + private void verifyProtoRoundTrip(Rel rel) { + ExtensionCollector functionCollector = new ExtensionCollector(); + RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); + ProtoRelConverter protoRelConverter = + new ProtoRelConverter(functionCollector, DefaultExtensionCatalog.DEFAULT_COLLECTION); + + io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + Rel relReturned = protoRelConverter.from(protoRel); + assertEquals(rel, relReturned); + } + + @Test + void multipleDifferentUserDefinedAnyTypesProtoRoundtrip() { + // Test that UserDefinedAny literals with different type names - proto only + // point wraps struct with two i32 fields, line wraps struct with two point fields + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct).build()); + UserDefinedLiteral pointLit = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); + + Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Expression.Literal.Struct lineStruct = + Expression.Literal.Struct.newBuilder() + .addFields(bldr1.build()) // reuse point struct as start + .addFields(bldr1.build()) // reuse point struct as end + .build(); + Any anyValue2 = Any.pack(bldr2.setStruct(lineStruct).build()); + UserDefinedLiteral lineLit = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + anyValue2); + + Rel originalRel = + b.project( + input -> List.of(pointLit, lineLit), + b.remap(1, 2), // Select both expressions + b.namedScan( + List.of("example"), + List.of("a"), + 
List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void singleUserDefinedAnyCalciteRoundtrip() { + // Test that a single UserDefinedAny literal can roundtrip through Calcite + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct).build()); + UserDefinedLiteral pointLit = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); + + Rel originalRel = + b.project( + input -> List.of(pointLit), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + @Test + void singleUserDefinedStructCalciteRoundtrip() { + // Test that a single UserDefinedStruct literal can roundtrip through Calcite + io.substrait.expression.Expression.UserDefinedStruct val = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.i32(false, 100)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(val), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + @Test + void nestedUserDefinedStructCalciteRoundtrip() { 
+ io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 5)) + .addFields(ExpressionCreator.i32(false, 15)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 25)) + .addFields(ExpressionCreator.i32(false, 35)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct lineStructLit = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) + .build(); + + Rel originalRel = + b.project( + input -> List.of(lineStructLit), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + void multipleDifferentUserDefinedAnyTypesCalciteRoundtrip() { + // Test that multiple UserDefinedAny literals with different types can roundtrip through Calcite + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct).build()); + UserDefinedLiteral pointLit = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); + + 
Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Expression.Literal.Struct lineStruct = + Expression.Literal.Struct.newBuilder() + .addFields(bldr1.build()) + .addFields(bldr1.build()) + .build(); + Any anyValue2 = Any.pack(bldr2.setStruct(lineStruct).build()); + UserDefinedLiteral lineLit = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + anyValue2); + + Rel originalRel = + b.project( + input -> List.of(pointLit, lineLit), + b.remap(1, 2), // Select both expressions + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + @Test + void userDefinedStructWithPrimitivesProtoRoundtrip() { + // Test UserDefinedStruct with primitive field types - proto roundtrip only + io.substrait.expression.Expression.UserDefinedStruct val = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.i32(false, 100)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(val), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void userDefinedStructWithNestedStructProtoRoundtrip() { + // Test UserDefinedStruct with nested UDT fields - proto roundtrip only + // line contains nested point UDT fields + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + 
.addFields(ExpressionCreator.i32(false, 10)) + .addFields(ExpressionCreator.i32(false, 20)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 30)) + .addFields(ExpressionCreator.i32(false, 40)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct line = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) + .build(); + + Rel originalRel = + b.project( + input -> List.of(line), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void userDefinedStructWithNestedAnyCalciteRoundtrip() { + // Mix struct-encoded and Any-encoded fields inside a single UserDefinedStruct literal + Expression.Literal.Builder pointBuilder = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStructProto = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(5)) + .addFields(Expression.Literal.newBuilder().setI32(10)) + .build(); + Any pointAny = Any.pack(pointBuilder.setStruct(pointStructProto).build()); + UserDefinedLiteral pointAnyLiteral = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + pointAny); + + io.substrait.expression.Expression.UserDefinedStruct pointStructLiteral = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 20)) + .addFields(ExpressionCreator.i32(false, 30)) + .build(); 
+ + io.substrait.expression.Expression.UserDefinedStruct line = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(pointAnyLiteral) + .addFields(pointStructLiteral) + .build(); + + Rel originalRel = + b.project( + input -> List.of(line), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + @Test + void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { + // Test multiple UserDefinedStruct types with different struct schemas + // point: {latitude: i32, longitude: i32} + // line: {start: point, end: point} + io.substrait.expression.Expression.UserDefinedStruct pointStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.i32(false, 100)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 10)) + .addFields(ExpressionCreator.i32(false, 20)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 30)) + .addFields(ExpressionCreator.i32(false, 40)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct lineStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + 
.urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) + .build(); + + Rel originalRel = + b.project( + input -> List.of(pointStruct, lineStruct), + b.remap(2), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void sameUdTypeDifferentEncodingsCalciteRoundtrip() { + // Validate that "line" UDT survives Calcite roundtrip in both Any and Struct encodings + Expression.Literal.Builder lineBuilder = Expression.Literal.newBuilder(); + Expression.Literal.Builder pointBuilder = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStructProto = + Expression.Literal.Struct.newBuilder() + .addFields(pointBuilder.clear().setI32(5).build()) + .addFields(pointBuilder.clear().setI32(15).build()) + .build(); + Expression.Literal.Struct lineStructProto = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setStruct(pointStructProto).build()) + .addFields(Expression.Literal.newBuilder().setStruct(pointStructProto).build()) + .build(); + Any lineAnyValue = Any.pack(lineBuilder.setStruct(lineStructProto).build()); + UserDefinedLiteral lineAny = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + lineAnyValue); + + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 1)) + .addFields(ExpressionCreator.i32(false, 2)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + 
.addFields(ExpressionCreator.i32(false, 3)) + .addFields(ExpressionCreator.i32(false, 4)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct lineStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) + .build(); + + Rel relWithAny = + b.project( + input -> List.of(lineAny), + b.remap(1), + b.namedScan( + List.of("example_any"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); + Rel relWithStruct = + b.project( + input -> List.of(lineStruct), + b.remap(1), + b.namedScan( + List.of("example_struct"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); + + Rel roundtrippedAny = calciteToSubstrait.apply(substraitToCalcite.convert(relWithAny)); + assertInstanceOf(io.substrait.relation.Project.class, roundtrippedAny); + io.substrait.relation.Project anyProject = (io.substrait.relation.Project) roundtrippedAny; + assertEquals(1, anyProject.getExpressions().size()); + assertInstanceOf( + io.substrait.expression.Expression.UserDefinedAny.class, + anyProject.getExpressions().get(0)); + + Rel roundtrippedStruct = calciteToSubstrait.apply(substraitToCalcite.convert(relWithStruct)); + assertInstanceOf(io.substrait.relation.Project.class, roundtrippedStruct); + io.substrait.relation.Project structProject = + (io.substrait.relation.Project) roundtrippedStruct; + assertEquals(1, structProject.getExpressions().size()); + assertInstanceOf( + io.substrait.expression.Expression.UserDefinedStruct.class, + structProject.getExpressions().get(0)); + } + + @Test + void intermixedUserDefinedAnyAndStructProtoRoundtrip() { + // Test intermixing UserDefinedAny and UserDefinedStruct in the same query + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStruct1 = + 
Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(10)) + .addFields(Expression.Literal.newBuilder().setI32(20)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct1).build()); + UserDefinedLiteral anyLit1 = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); + + io.substrait.expression.Expression.UserDefinedStruct structLit1 = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 123)) + .addFields(ExpressionCreator.i32(false, 456)) + .build(); + + Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStruct2 = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(30)) + .addFields(Expression.Literal.newBuilder().setI32(40)) + .build(); + Any anyValue2 = Any.pack(bldr2.setStruct(pointStruct2).build()); + UserDefinedLiteral anyLit2 = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue2); + + io.substrait.expression.Expression.UserDefinedStruct structLit2 = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 789)) + .addFields(ExpressionCreator.i32(false, 101)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(anyLit1, structLit1, anyLit2, structLit2), + b.remap(4), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void multipleDifferentUDTTypesWithAnyAndStructProtoRoundtrip() { + // Test multiple 
different UDT type names (point, line) with both Any and Struct + Expression.Literal.Builder pointBldr = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any pointAny = Any.pack(pointBldr.setStruct(pointStruct).build()); + UserDefinedLiteral pointAny1 = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + pointAny); + + io.substrait.expression.Expression.UserDefinedStruct pointStructLit = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 10)) + .addFields(ExpressionCreator.i32(false, 20)) + .build(); + + Expression.Literal.Builder lineBldr = Expression.Literal.newBuilder(); + Expression.Literal.Struct lineStruct = + Expression.Literal.Struct.newBuilder() + .addFields(pointBldr.build()) + .addFields(pointBldr.build()) + .build(); + Any lineAny = Any.pack(lineBldr.setStruct(lineStruct).build()); + UserDefinedLiteral lineAny1 = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + lineAny); + + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 50)) + .addFields(ExpressionCreator.i32(false, 60)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + 
.addFields(ExpressionCreator.i32(false, 70)) + .addFields(ExpressionCreator.i32(false, 80)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct lineStructLit = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) + .build(); + + Rel originalRel = + b.project( + input -> List.of(pointAny1, pointStructLit, lineAny1, lineStructLit), + b.remap(4), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + verifyProtoRoundTrip(originalRel); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/type/SubstraitUserDefinedTypeTest.java b/isthmus/src/test/java/io/substrait/isthmus/type/SubstraitUserDefinedTypeTest.java new file mode 100644 index 000000000..105d30091 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/type/SubstraitUserDefinedTypeTest.java @@ -0,0 +1,42 @@ +package io.substrait.isthmus.type; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import io.substrait.proto.Type; +import java.util.List; +import org.junit.jupiter.api.Test; + +class SubstraitUserDefinedTypeTest { + private static final String URN = "extension:io.substrait:test"; + private static final String NAME = "custom"; + + @Test + void differentTypeParametersProduceDifferentDigests() { + Type.Parameter integerParam = Type.Parameter.newBuilder().setInteger(1).build(); + Type.Parameter enumParam = Type.Parameter.newBuilder().setEnum("value").build(); + + SubstraitUserDefinedType typeWithInteger = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(integerParam), false); + SubstraitUserDefinedType typeWithEnum = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(enumParam), false); + + 
assertNotEquals(typeWithInteger, typeWithEnum); + assertNotEquals(typeWithInteger.toString(), typeWithEnum.toString()); + } + + @Test + void sameParametersRemainEqual() { + Type.Parameter integerParam = Type.Parameter.newBuilder().setInteger(7).build(); + SubstraitUserDefinedType left = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(integerParam), true); + SubstraitUserDefinedType right = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(integerParam), true); + + assertEquals(left, right); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 2c90f133d..d230f466d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -1,5 +1,6 @@ package io.substrait.isthmus.utils; +import io.substrait.isthmus.type.SubstraitUserDefinedType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import org.apache.calcite.rel.type.RelDataType; @@ -35,11 +36,48 @@ public Type createSubstrait(boolean nullable) { return TypeCreator.of(nullable).userDefined(urn, name); } + /** + * Test-specific variant of the core implementation that treats Calcite-copied types as + * equivalent, even when they are not the same Java instance. Calcite often clones UDTs during + * planning, so reference equality alone would fail in tests. + */ public boolean isTypeFromFactory(RelDataType type) { - return type == N || type == R; + return matchesSubstraitType(type) || type == N || type == R || matchesCalciteAlias(type); } - private static class InnerType extends RelDataTypeImpl { + /** + * Detects Substrait-backed Calcite types by interrogating their metadata. + * + *
<p>
If Calcite preserves the original {@link SubstraitUserDefinedType}, the urn/name are both + * available directly. Otherwise, Calcite may create an anonymous {@link RelDataTypeImpl} copy, + * exposing only its alias string. In that case we fall back to comparing the formatted alias (see + * {@link InnerType#generateTypeString}). + */ + private boolean matchesSubstraitType(RelDataType type) { + if (type instanceof SubstraitUserDefinedType) { + SubstraitUserDefinedType udt = (SubstraitUserDefinedType) type; + return this.urn.equals(udt.getUrn()) && this.name.equals(udt.getName()); + } + return false; + } + + /** + * Calcite may copy a user-defined type into an anonymous {@link RelDataTypeImpl} where the only + * identifier left is its alias string. This helper captures the "find by alias" fallback so it’s + * clear we’re matching against the formatted urn::name when the rich metadata is not + * available. + */ + private boolean matchesCalciteAlias(RelDataType type) { + return type != null + && (type.getSqlTypeName() == SqlTypeName.OTHER || type.getSqlTypeName() == SqlTypeName.ROW) + && type.toString().equals(calciteDisplayName()); + } + + private String calciteDisplayName() { + return String.format("%s::%s", this.urn, this.name); + } + + private class InnerType extends RelDataTypeImpl { private final boolean nullable; private final String name; @@ -61,7 +99,7 @@ public SqlTypeName getSqlTypeName() { @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - sb.append(name); + sb.append(UserTypeFactory.this.urn).append("::").append(name); } } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactoryTest.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactoryTest.java new file mode 100644 index 000000000..df1c2dcd0 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactoryTest.java @@ -0,0 +1,35 @@ +package io.substrait.isthmus.utils; + +import static 
org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.isthmus.type.SubstraitUserDefinedType; +import io.substrait.type.Type; +import org.apache.calcite.rel.type.RelDataType; +import org.junit.jupiter.api.Test; + +class UserTypeFactoryTest { + + private static final String URN = "extension:io.substrait:test"; + private static final String NAME = "custom_type"; + + @Test + void detectsSubstraitUserDefinedType() { + UserTypeFactory factory = new UserTypeFactory(URN, NAME); + RelDataType substraitType = + SubstraitUserDefinedType.from( + Type.UserDefined.builder().nullable(true).urn(URN).name(NAME).build()); + + assertTrue(factory.isTypeFromFactory(substraitType)); + } + + @Test + void rejectsDifferentUrnOrName() { + UserTypeFactory factory = new UserTypeFactory(URN, NAME); + RelDataType differentType = + SubstraitUserDefinedType.from( + Type.UserDefined.builder().nullable(true).urn(URN).name("other").build()); + + assertFalse(factory.isTypeFromFactory(differentType)); + } +} diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 5377f4257..133052fa1 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -76,8 +76,12 @@ class ExpressionToString extends DefaultExpressionVisitor[String] { s"${expr.declaration().key()}[${expr.outputType().accept(ToTypeString.INSTANCE)}]($args)" } + override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): String = { + expr.toString + } + override def visit( - expr: Expression.UserDefinedLiteral, + expr: Expression.UserDefinedStruct, context: EmptyVisitationContext): String = { expr.toString } diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala 
index 5f7137b14..07594d3bf 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -65,9 +65,9 @@ class DefaultExpressionVisitor[T] context: EmptyVisitationContext): T = e.accept(this, context) - override def visit( - userDefinedLiteral: Expression.UserDefinedLiteral, - context: EmptyVisitationContext): T = { - visitFallback(userDefinedLiteral, context) - } + override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): T = + visitFallback(expr, context) + + override def visit(expr: Expression.UserDefinedStruct, context: EmptyVisitationContext): T = + visitFallback(expr, context) }