Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][SPARK-49690][SQL] UDT type not showing up in sql type representation in the schema #48174

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ trait AgnosticEncoder[T] extends Encoder[T] {
def isPrimitive: Boolean
def nullable: Boolean = !isPrimitive
def dataType: DataType
override def schema: StructType = StructType(StructField("value", dataType, nullable) :: Nil)
override def schema: StructType = StructType(
StructField("value", DataType.udtToSqlType(dataType), nullable) :: Nil)
def lenientSerialization: Boolean = false
def isStruct: Boolean = false
}
Expand Down Expand Up @@ -108,7 +109,8 @@ object AgnosticEncoders {
metadata: Metadata,
readMethod: Option[String] = None,
writeMethod: Option[String] = None) {
def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
def structField: StructField =
StructField(name, DataType.udtToSqlType(enc.dataType), nullable, metadata)
}

// Contains a sequence of fields.
Expand Down Expand Up @@ -178,6 +180,9 @@ object AgnosticEncoders {
override def isPrimitive: Boolean = false
override def dataType: DataType = udt
override def clsTag: ClassTag[E] = ClassTag(udt.userClass)

override def schema: StructType = StructType(
StructField("value", udt.sqlType, nullable) :: Nil)
}

// Enums are special leafs because we need to capture the class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,4 +570,10 @@ object DataType {
case (fromDataType, toDataType) => fromDataType == toDataType
}
}

/**
 * Maps a [[UserDefinedType]] to its underlying SQL representation; any other
 * [[DataType]] is returned unchanged.
 *
 * NOTE(review): this is a shallow conversion — a UDT nested inside a struct,
 * array or map element type is presumably left as-is; confirm callers only
 * need the top-level type converted.
 *
 * @param dataType the type to convert
 * @return the UDT's `sqlType`, or `dataType` itself otherwise
 */
def udtToSqlType(dataType: DataType): DataType =
  dataType match {
    case userType: UserDefinedType[_] => userType.sqlType
    case other => other
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// ResolveLateralColumnAliasReference for more details.
ResolveLateralColumnAliasReference ::
ResolveExpressionsWithNamePlaceholders ::
FixUDTDeserializerUpcast ::
ResolveDeserializer ::
ResolveNewInstance ::
ResolveUpCast ::
Expand Down Expand Up @@ -3748,6 +3749,25 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
project
}

/**
 * Rewrites a not-yet-resolved deserializer whose [[UpCast]] targets are UDTs, for the
 * case where the resolved child plan's output contains no UDT attributes (which will be
 * the case when the plan is serialized as Rows, so UDT columns already appear as their
 * underlying SQL types). The UpCast target is replaced with the UDT's `sqlType` so that
 * deserializer resolution can succeed against the SQL-typed input.
 */
object FixUDTDeserializerUpcast extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
    _.containsPattern(DESERIALIZE_TO_OBJECT), UnknownRuleId) {
    // Fire only while the deserializer is unresolved, the child is resolved, and
    // none of the child's output attributes carries a UserDefinedType.
    case d@DeserializeToObject(deserializer, _, child) if !deserializer.resolved &&
      child.resolved && child.output.forall(
        x => !classOf[UserDefinedType[_]].isAssignableFrom(x.dataType.getClass)) =>
      // Swap every UpCast-to-UDT inside the deserializer for an UpCast to the
      // UDT's SQL representation, leaving all other expressions untouched.
      val newDeserializer = deserializer.transformUpWithPruning(_.containsPattern(UP_CAST)) {
        case UpCast(child, target: UserDefinedType[_], walkedPath)
          => UpCast(child, target.sqlType, walkedPath)
      }
      d.copy(deserializer = newDeserializer)
  }
}

/**
* Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
* to the given input attributes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, Simpli
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{DataType, ObjectType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -211,7 +211,7 @@ case class ExpressionEncoder[T](
// The schema after converting `T` to a Spark SQL row. This schema is dependent on the given
// serializer.
val schema: StructType = StructType(serializer.map { s =>
StructField(s.name, s.dataType, s.nullable)
StructField(s.name, DataType.udtToSqlType(s.dataType), s.nullable)
})

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.{SPARK_DOC_ROOT, SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData, ScroogeLikeExample}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, JavaTypeInference, OptionalData, PrimitiveData, ScroogeLikeExample, UDTBean}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}
Expand Down Expand Up @@ -430,6 +430,12 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes

productTest(("UDT", new ExamplePoint(0.1, 0.2)))

test("SPARK-49690: UDT schema not showing the sql type") {
  // The UDT-typed field of UDTBean should surface in the encoder schema as its
  // underlying SQL type (StringType), not as the UDT itself.
  val beanEncoder = JavaTypeInference.encoderFor(classOf[UDTBean])
  val expected = StructType(StructField("udt", StringType, nullable = true) :: Nil)
  assert(beanEncoder.schema === expected)
}

test("AnyVal class with Any fields") {
val exception = intercept[SparkUnsupportedOperationException](
implicitly[ExpressionEncoder[Foo]])
Expand Down