@@ -30,12 +30,13 @@ import scala.util.Try
 import com.google.protobuf.ByteString

 import org.apache.spark.connect.proto
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.SparkClassUtils

 object LiteralValueProtoConverter {

@@ -223,52 +224,51 @@ object LiteralValueProtoConverter {
     val sb = builder.getStructBuilder
     val fields = structType.fields

-    scalaValue match {
+    val iter = scalaValue match {
       case p: Product =>
-        val iter = p.productIterator
-        var idx = 0
-        if (options.useDeprecatedDataTypeFields) {
-          while (idx < structType.size) {
-            val field = fields(idx)
-            // For backward compatibility, we need the data type for each field.
-            val literalProto = toLiteralProtoBuilderInternal(
-              iter.next(),
-              field.dataType,
-              options,
-              needDataType = true).build()
-            sb.addElements(literalProto)
-            idx += 1
-          }
-          sb.setStructType(toConnectProtoType(structType))
-        } else {
-          while (idx < structType.size) {
-            val field = fields(idx)
-            val literalProto =
-              toLiteralProtoBuilderInternal(iter.next(), field.dataType, options, needDataType)
-                .build()
-            sb.addElements(literalProto)
-
-            if (needDataType) {
-              val fieldBuilder = sb.getDataTypeStructBuilder
-                .addFieldsBuilder()
-                .setName(field.name)
-                .setNullable(field.nullable)
-
-              if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
-                fieldBuilder.setDataType(toConnectProtoType(field.dataType))
-              }
-
-              // Set metadata if available
-              if (field.metadata != Metadata.empty) {
-                fieldBuilder.setMetadata(field.metadata.json)
-              }
-            }
+        p.productIterator
+      case r: Row =>
+        r.toSeq.iterator
+      case other =>
+        throw new IllegalArgumentException(
+          s"literal ${other.getClass.getName}($other) not supported (yet).")
+    }

-            idx += 1
-          }
+    var idx = 0
+    if (options.useDeprecatedDataTypeFields) {
+      while (idx < structType.size) {
+        val field = fields(idx)
+        val literalProto =
+          toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+        sb.addElements(literalProto)
+        idx += 1
+      }
+      sb.setStructType(toConnectProtoType(structType))
+    } else {
+      val dataTypeStruct = proto.DataType.Struct.newBuilder()
+      while (idx < structType.size) {
+        val field = fields(idx)
+        val literalProto =
+          toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+        sb.addElements(literalProto)
+
+        val fieldBuilder = dataTypeStruct
+          .addFieldsBuilder()
+          .setName(field.name)
+          .setNullable(field.nullable)
+
+        if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
+          fieldBuilder.setDataType(toConnectProtoType(field.dataType))
         }
-      case other =>
-        throw new IllegalArgumentException(s"literal $other not supported (yet).")
+
+        // Set metadata if available
+        if (field.metadata != Metadata.empty) {
+          fieldBuilder.setMetadata(field.metadata.json)
+        }
+
+        idx += 1
+      }
+      sb.setDataTypeStruct(dataTypeStruct.build())
     }

     sb
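
For context, a rough sketch of what the serialization change enables, assuming this object's existing toLiteralProto(Any, DataType) overload (the schema and values below are made up): a Row is now accepted alongside Product values such as tuples and case classes.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("name", StringType)))

// Previously only Product values (tuples, case classes) reached the struct builder.
val fromTuple = LiteralValueProtoConverter.toLiteralProto((1, "a"), schema)
// With this change a Row of the same shape takes the same path via r.toSeq.iterator.
val fromRow = LiteralValueProtoConverter.toLiteralProto(Row(1, "a"), schema)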
@@ -721,23 +721,12 @@ object LiteralValueProtoConverter {
   private def toScalaStructInternal(
       struct: proto.Expression.Literal.Struct,
       structType: proto.DataType.Struct): Any = {
-    def toTuple[A <: Object](data: Seq[A]): Product = {
-      try {
-        val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
-        tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
-      } catch {
-        case _: Exception =>
-          throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
-      }
-    }
-
-    val size = struct.getElementsCount
-    val structData = Seq.tabulate(size) { i =>
+    val structData = Array.tabulate(struct.getElementsCount) { i =>
       val element = struct.getElements(i)
       val dataType = structType.getFields(i).getDataType
-      getConverter(dataType)(element).asInstanceOf[Object]
+      getConverter(dataType)(element)
     }
-    toTuple(structData)
+    new GenericRowWithSchema(structData, DataTypeProtoConverter.toCatalystStructType(structType))
   }

   def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
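
On the deserialization side, the visible behavior change is the return type: a struct literal that used to be rebuilt as a scala.TupleN via reflection now comes back as a Row backed by GenericRowWithSchema. Continuing the sketch above (same caveats; names other than toScalaStruct are assumed):

LiteralValueProtoConverter.toScalaStruct(fromRow.getStruct) match {
  // Callers that previously matched on tuples, e.g. case (i: Int, s: String) =>, now see a Row.
  case row: Row => (row.getInt(0), row.getString(1))
}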
@@ -759,4 +748,77 @@ object LiteralValueProtoConverter {
   def toScalaStruct(struct: proto.Expression.Literal.Struct): Any = {
     toScalaStructInternal(struct, getProtoStructType(struct))
   }
+
+  def getDataType(lit: proto.Expression.Literal): DataType = {
+    lit.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.NULL =>
+        DataTypeProtoConverter.toCatalystType(lit.getNull)
+      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+        BinaryType
+      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+        BooleanType
+      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+        ByteType
+      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+        ShortType
+      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+        IntegerType
+      case proto.Expression.Literal.LiteralTypeCase.LONG =>
+        LongType
+      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+        FloatType
+      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+        DoubleType
+      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+        val decimal = Decimal.apply(lit.getDecimal.getValue)
+        var precision = decimal.precision
+        if (lit.getDecimal.hasPrecision) {
+          precision = math.max(precision, lit.getDecimal.getPrecision)
+        }
+        var scale = decimal.scale
+        if (lit.getDecimal.hasScale) {
+          scale = math.max(scale, lit.getDecimal.getScale)
+        }
+        DecimalType(math.max(precision, scale), scale)
+      case proto.Expression.Literal.LiteralTypeCase.STRING =>
+        StringType
+      case proto.Expression.Literal.LiteralTypeCase.DATE =>
+        DateType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+        TimestampType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+        TimestampNTZType
+      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+        CalendarIntervalType
+      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+        YearMonthIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+        DayTimeIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.TIME =>
+        var precision = TimeType.DEFAULT_PRECISION
+        if (lit.getTime.hasPrecision) {
+          precision = lit.getTime.getPrecision
+        }
+        TimeType(precision)
+      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.MAP =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
+            .build())
+      case _ =>
+        throw InvalidPlanInput(
+          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
+            s"(${lit.getLiteralTypeCase.getNumber})")
+    }
+  }
 }
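
The new getDataType helper derives a Catalyst DataType from a literal proto without fully converting the value. A rough illustration, assuming the standard generated builders on proto.Expression.Literal (values made up):

val intLit = proto.Expression.Literal.newBuilder().setInteger(1).build()
LiteralValueProtoConverter.getDataType(intLit) // IntegerType

// DECIMAL reconciles the precision/scale parsed from the value with the optional proto fields:
// "1.23" parses as (precision = 3, scale = 2) and the proto asks for (10, 2),
// so the result is DecimalType(10, 2).
val decLit = proto.Expression.Literal
  .newBuilder()
  .setDecimal(
    proto.Expression.Literal.Decimal
      .newBuilder()
      .setValue("1.23")
      .setPrecision(10)
      .setScale(2))
  .build()
LiteralValueProtoConverter.getDataType(decLit) // DecimalType(10, 2)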