@@ -30,12 +30,13 @@ import scala.util.Try
30
30
import com .google .protobuf .ByteString
31
31
32
32
import org .apache .spark .connect .proto
33
+ import org .apache .spark .sql .Row
33
34
import org .apache .spark .sql .catalyst .ScalaReflection
35
+ import org .apache .spark .sql .catalyst .expressions .GenericRowWithSchema
34
36
import org .apache .spark .sql .catalyst .util .{SparkDateTimeUtils , SparkIntervalUtils }
35
37
import org .apache .spark .sql .connect .common .DataTypeProtoConverter ._
36
38
import org .apache .spark .sql .types ._
37
39
import org .apache .spark .unsafe .types .CalendarInterval
38
- import org .apache .spark .util .SparkClassUtils
39
40
40
41
object LiteralValueProtoConverter {
41
42
@@ -182,47 +183,52 @@ object LiteralValueProtoConverter {
182
183
val sb = builder.getStructBuilder
183
184
val fields = structType.fields
184
185
185
- scalaValue match {
186
+ val iter = scalaValue match {
186
187
case p : Product =>
187
- val iter = p.productIterator
188
- var idx = 0
189
- if (options.useDeprecatedDataTypeFields) {
190
- while (idx < structType.size) {
191
- val field = fields(idx)
192
- val literalProto =
193
- toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
194
- sb.addElements(literalProto)
195
- idx += 1
196
- }
197
- sb.setStructType(toConnectProtoType(structType))
198
- } else {
199
- val dataTypeStruct = proto.DataType .Struct .newBuilder()
200
- while (idx < structType.size) {
201
- val field = fields(idx)
202
- val literalProto =
203
- toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
204
- sb.addElements(literalProto)
205
-
206
- val fieldBuilder = dataTypeStruct
207
- .addFieldsBuilder()
208
- .setName(field.name)
209
- .setNullable(field.nullable)
210
-
211
- if (LiteralValueProtoConverter .getInferredDataType(literalProto).isEmpty) {
212
- fieldBuilder.setDataType(toConnectProtoType(field.dataType))
213
- }
188
+ p.productIterator
189
+ case r : Row =>
190
+ r.toSeq.iterator
191
+ case other =>
192
+ throw new IllegalArgumentException (
193
+ s " literal ${other.getClass.getName}( $other) not supported (yet). " +
194
+ s " ${structType.catalogString}" )
195
+ }
214
196
215
- // Set metadata if available
216
- if (field.metadata != Metadata .empty) {
217
- fieldBuilder.setMetadata(field.metadata.json)
218
- }
197
+ var idx = 0
198
+ if (options.useDeprecatedDataTypeFields) {
199
+ while (idx < structType.size) {
200
+ val field = fields(idx)
201
+ val literalProto =
202
+ toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
203
+ sb.addElements(literalProto)
204
+ idx += 1
205
+ }
206
+ sb.setStructType(toConnectProtoType(structType))
207
+ } else {
208
+ val dataTypeStruct = proto.DataType .Struct .newBuilder()
209
+ while (idx < structType.size) {
210
+ val field = fields(idx)
211
+ val literalProto =
212
+ toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
213
+ sb.addElements(literalProto)
214
+
215
+ val fieldBuilder = dataTypeStruct
216
+ .addFieldsBuilder()
217
+ .setName(field.name)
218
+ .setNullable(field.nullable)
219
+
220
+ if (LiteralValueProtoConverter .getInferredDataType(literalProto).isEmpty) {
221
+ fieldBuilder.setDataType(toConnectProtoType(field.dataType))
222
+ }
219
223
220
- idx += 1
221
- }
222
- sb.setDataTypeStruct(dataTypeStruct.build() )
224
+ // Set metadata if available
225
+ if (field.metadata != Metadata .empty) {
226
+ fieldBuilder.setMetadata(field.metadata.json )
223
227
}
224
- case other =>
225
- throw new IllegalArgumentException (s " literal $other not supported (yet). " )
228
+
229
+ idx += 1
230
+ }
231
+ sb.setDataTypeStruct(dataTypeStruct.build())
226
232
}
227
233
228
234
sb
@@ -414,6 +420,9 @@ object LiteralValueProtoConverter {
414
420
case proto.Expression .Literal .LiteralTypeCase .ARRAY =>
415
421
toCatalystArray(literal.getArray)
416
422
423
+ case proto.Expression .Literal .LiteralTypeCase .MAP =>
424
+ toCatalystMap(literal.getMap)
425
+
417
426
case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
418
427
toCatalystStruct(literal.getStruct)
419
428
@@ -668,23 +677,12 @@ object LiteralValueProtoConverter {
668
677
private def toCatalystStructInternal (
669
678
struct : proto.Expression .Literal .Struct ,
670
679
structType : proto.DataType .Struct ): Any = {
671
- def toTuple [A <: Object ](data : Seq [A ]): Product = {
672
- try {
673
- val tupleClass = SparkClassUtils .classForName(s " scala.Tuple ${data.length}" )
674
- tupleClass.getConstructors.head.newInstance(data : _* ).asInstanceOf [Product ]
675
- } catch {
676
- case _ : Exception =>
677
- throw InvalidPlanInput (s " Unsupported Literal: ${data.mkString(" Array(" , " , " , " )" )}) " )
678
- }
679
- }
680
-
681
- val size = struct.getElementsCount
682
- val structData = Seq .tabulate(size) { i =>
680
+ val structData = Array .tabulate(struct.getElementsCount) { i =>
683
681
val element = struct.getElements(i)
684
682
val dataType = structType.getFields(i).getDataType
685
- getConverter(dataType)(element). asInstanceOf [ Object ]
683
+ getConverter(dataType)(element)
686
684
}
687
- toTuple (structData)
685
+ new GenericRowWithSchema (structData, DataTypeProtoConverter .toCatalystStructType(structType) )
688
686
}
689
687
690
688
def getProtoStructType (struct : proto.Expression .Literal .Struct ): proto.DataType .Struct = {
@@ -706,4 +704,80 @@ object LiteralValueProtoConverter {
706
704
def toCatalystStruct (struct : proto.Expression .Literal .Struct ): Any = {
707
705
toCatalystStructInternal(struct, getProtoStructType(struct))
708
706
}
707
+
708
+ def getDataType (lit : proto.Expression .Literal ): DataType = {
709
+ lit.getLiteralTypeCase match {
710
+ case proto.Expression .Literal .LiteralTypeCase .NULL =>
711
+ DataTypeProtoConverter .toCatalystType(lit.getNull)
712
+ case proto.Expression .Literal .LiteralTypeCase .BINARY =>
713
+ BinaryType
714
+ case proto.Expression .Literal .LiteralTypeCase .BOOLEAN =>
715
+ BooleanType
716
+ case proto.Expression .Literal .LiteralTypeCase .BYTE =>
717
+ ByteType
718
+ case proto.Expression .Literal .LiteralTypeCase .SHORT =>
719
+ ShortType
720
+ case proto.Expression .Literal .LiteralTypeCase .INTEGER =>
721
+ IntegerType
722
+ case proto.Expression .Literal .LiteralTypeCase .LONG =>
723
+ LongType
724
+ case proto.Expression .Literal .LiteralTypeCase .FLOAT =>
725
+ FloatType
726
+ case proto.Expression .Literal .LiteralTypeCase .DOUBLE =>
727
+ DoubleType
728
+ case proto.Expression .Literal .LiteralTypeCase .DECIMAL =>
729
+ val decimal = Decimal .apply(lit.getDecimal.getValue)
730
+ var precision = decimal.precision
731
+ if (lit.getDecimal.hasPrecision) {
732
+ precision = math.max(precision, lit.getDecimal.getPrecision)
733
+ }
734
+ var scale = decimal.scale
735
+ if (lit.getDecimal.hasScale) {
736
+ scale = math.max(scale, lit.getDecimal.getScale)
737
+ }
738
+ DecimalType (math.max(precision, scale), scale)
739
+ case proto.Expression .Literal .LiteralTypeCase .STRING =>
740
+ StringType
741
+ case proto.Expression .Literal .LiteralTypeCase .DATE =>
742
+ DateType
743
+ case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP =>
744
+ TimestampType
745
+ case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP_NTZ =>
746
+ TimestampNTZType
747
+ case proto.Expression .Literal .LiteralTypeCase .CALENDAR_INTERVAL =>
748
+ CalendarIntervalType
749
+ case proto.Expression .Literal .LiteralTypeCase .YEAR_MONTH_INTERVAL =>
750
+ YearMonthIntervalType ()
751
+ case proto.Expression .Literal .LiteralTypeCase .DAY_TIME_INTERVAL =>
752
+ DayTimeIntervalType ()
753
+ case proto.Expression .Literal .LiteralTypeCase .TIME =>
754
+ var precision = TimeType .DEFAULT_PRECISION
755
+ if (lit.getTime.hasPrecision) {
756
+ precision = lit.getTime.getPrecision
757
+ }
758
+ TimeType (precision)
759
+ case proto.Expression .Literal .LiteralTypeCase .ARRAY =>
760
+ val arrayData = LiteralValueProtoConverter .toCatalystArray(lit.getArray)
761
+ DataTypeProtoConverter .toCatalystType(
762
+ proto.DataType .newBuilder
763
+ .setArray(LiteralValueProtoConverter .getProtoArrayType(lit.getArray))
764
+ .build())
765
+ case proto.Expression .Literal .LiteralTypeCase .MAP =>
766
+ LiteralValueProtoConverter .toCatalystMap(lit.getMap)
767
+ DataTypeProtoConverter .toCatalystType(
768
+ proto.DataType .newBuilder
769
+ .setMap(LiteralValueProtoConverter .getProtoMapType(lit.getMap))
770
+ .build())
771
+ case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
772
+ LiteralValueProtoConverter .toCatalystStruct(lit.getStruct)
773
+ DataTypeProtoConverter .toCatalystType(
774
+ proto.DataType .newBuilder
775
+ .setStruct(LiteralValueProtoConverter .getProtoStructType(lit.getStruct))
776
+ .build())
777
+ case _ =>
778
+ throw InvalidPlanInput (
779
+ s " Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
780
+ s " ( ${lit.getLiteralTypeCase.getNumber}) " )
781
+ }
782
+ }
709
783
}
0 commit comments