@@ -30,12 +30,13 @@ import scala.util.Try
3030import  com .google .protobuf .ByteString 
3131
3232import  org .apache .spark .connect .proto 
33+ import  org .apache .spark .sql .Row 
3334import  org .apache .spark .sql .catalyst .ScalaReflection 
35+ import  org .apache .spark .sql .catalyst .expressions .GenericRowWithSchema 
3436import  org .apache .spark .sql .catalyst .util .{SparkDateTimeUtils , SparkIntervalUtils }
3537import  org .apache .spark .sql .connect .common .DataTypeProtoConverter ._ 
3638import  org .apache .spark .sql .types ._ 
3739import  org .apache .spark .unsafe .types .CalendarInterval 
38- import  org .apache .spark .util .SparkClassUtils 
3940
4041object  LiteralValueProtoConverter  {
4142
@@ -182,47 +183,52 @@ object LiteralValueProtoConverter {
182183      val  sb  =  builder.getStructBuilder
183184      val  fields  =  structType.fields
184185
185-       scalaValue match  {
186+       val   iter   =   scalaValue match  {
186187        case  p : Product  => 
187-           val  iter  =  p.productIterator
188-           var  idx  =  0 
189-           if  (options.useDeprecatedDataTypeFields) {
190-             while  (idx <  structType.size) {
191-               val  field  =  fields(idx)
192-               val  literalProto  = 
193-                 toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
194-               sb.addElements(literalProto)
195-               idx +=  1 
196-             }
197-             sb.setStructType(toConnectProtoType(structType))
198-           } else  {
199-             val  dataTypeStruct  =  proto.DataType .Struct .newBuilder()
200-             while  (idx <  structType.size) {
201-               val  field  =  fields(idx)
202-               val  literalProto  = 
203-                 toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
204-               sb.addElements(literalProto)
205- 
206-               val  fieldBuilder  =  dataTypeStruct
207-                 .addFieldsBuilder()
208-                 .setName(field.name)
209-                 .setNullable(field.nullable)
210- 
211-               if  (LiteralValueProtoConverter .getInferredDataType(literalProto).isEmpty) {
212-                 fieldBuilder.setDataType(toConnectProtoType(field.dataType))
213-               }
188+           p.productIterator
189+         case  r : Row  => 
190+           r.toSeq.iterator
191+         case  other => 
192+           throw  new  IllegalArgumentException (
193+             s " literal  ${other.getClass.getName}( $other) not supported (yet). "  + 
194+               s "   ${structType.catalogString}" )
195+       }
214196
215-               //  Set metadata if available
216-               if  (field.metadata !=  Metadata .empty) {
217-                 fieldBuilder.setMetadata(field.metadata.json)
218-               }
197+       var  idx  =  0 
198+       if  (options.useDeprecatedDataTypeFields) {
199+         while  (idx <  structType.size) {
200+           val  field  =  fields(idx)
201+           val  literalProto  = 
202+             toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
203+           sb.addElements(literalProto)
204+           idx +=  1 
205+         }
206+         sb.setStructType(toConnectProtoType(structType))
207+       } else  {
208+         val  dataTypeStruct  =  proto.DataType .Struct .newBuilder()
209+         while  (idx <  structType.size) {
210+           val  field  =  fields(idx)
211+           val  literalProto  = 
212+             toLiteralProtoWithOptions(iter.next(), Some (field.dataType), options)
213+           sb.addElements(literalProto)
214+ 
215+           val  fieldBuilder  =  dataTypeStruct
216+             .addFieldsBuilder()
217+             .setName(field.name)
218+             .setNullable(field.nullable)
219+ 
220+           if  (LiteralValueProtoConverter .getInferredDataType(literalProto).isEmpty) {
221+             fieldBuilder.setDataType(toConnectProtoType(field.dataType))
222+           }
219223
220-               idx  +=   1 
221-             } 
222-             sb.setDataTypeStruct(dataTypeStruct.build() )
224+           //  Set metadata if available 
225+           if  (field.metadata  !=   Metadata .empty) { 
226+             fieldBuilder.setMetadata(field.metadata.json )
223227          }
224-         case  other => 
225-           throw  new  IllegalArgumentException (s " literal  $other not supported (yet). " )
228+ 
229+           idx +=  1 
230+         }
231+         sb.setDataTypeStruct(dataTypeStruct.build())
226232      }
227233
228234      sb
@@ -668,23 +674,13 @@ object LiteralValueProtoConverter {
668674  private  def  toCatalystStructInternal (
669675      struct : proto.Expression .Literal .Struct ,
670676      structType : proto.DataType .Struct ):  Any  =  {
671-     def  toTuple [A  <:  Object ](data : Seq [A ]):  Product  =  {
672-       try  {
673-         val  tupleClass  =  SparkClassUtils .classForName(s " scala.Tuple ${data.length}" )
674-         tupleClass.getConstructors.head.newInstance(data : _* ).asInstanceOf [Product ]
675-       } catch  {
676-         case  _ : Exception  => 
677-           throw  InvalidPlanInput (s " Unsupported Literal:  ${data.mkString(" Array("  , " , "  , " )"  )}) " )
678-       }
679-     }
680- 
681677    val  size  =  struct.getElementsCount
682-     val  structData  =  Seq .tabulate(size) { i => 
678+     val  structData  =  Array .tabulate(size) { i => 
683679      val  element  =  struct.getElements(i)
684680      val  dataType  =  structType.getFields(i).getDataType
685-       getConverter(dataType)(element). asInstanceOf [ Object ] 
681+       getConverter(dataType)(element)
686682    }
687-     toTuple (structData)
683+     new   GenericRowWithSchema (structData,  DataTypeProtoConverter .toCatalystStructType(structType) )
688684  }
689685
690686  def  getProtoStructType (struct : proto.Expression .Literal .Struct ):  proto.DataType .Struct  =  {
@@ -706,4 +702,80 @@ object LiteralValueProtoConverter {
706702  def  toCatalystStruct (struct : proto.Expression .Literal .Struct ):  Any  =  {
707703    toCatalystStructInternal(struct, getProtoStructType(struct))
708704  }
705+ 
706+   def  getDataType (lit : proto.Expression .Literal ):  DataType  =  {
707+     lit.getLiteralTypeCase match  {
708+       case  proto.Expression .Literal .LiteralTypeCase .NULL  => 
709+         DataTypeProtoConverter .toCatalystType(lit.getNull)
710+       case  proto.Expression .Literal .LiteralTypeCase .BINARY  => 
711+         BinaryType 
712+       case  proto.Expression .Literal .LiteralTypeCase .BOOLEAN  => 
713+         BooleanType 
714+       case  proto.Expression .Literal .LiteralTypeCase .BYTE  => 
715+         ByteType 
716+       case  proto.Expression .Literal .LiteralTypeCase .SHORT  => 
717+         ShortType 
718+       case  proto.Expression .Literal .LiteralTypeCase .INTEGER  => 
719+         IntegerType 
720+       case  proto.Expression .Literal .LiteralTypeCase .LONG  => 
721+         LongType 
722+       case  proto.Expression .Literal .LiteralTypeCase .FLOAT  => 
723+         FloatType 
724+       case  proto.Expression .Literal .LiteralTypeCase .DOUBLE  => 
725+         DoubleType 
726+       case  proto.Expression .Literal .LiteralTypeCase .DECIMAL  => 
727+         val  decimal  =  Decimal .apply(lit.getDecimal.getValue)
728+         var  precision  =  decimal.precision
729+         if  (lit.getDecimal.hasPrecision) {
730+           precision =  math.max(precision, lit.getDecimal.getPrecision)
731+         }
732+         var  scale  =  decimal.scale
733+         if  (lit.getDecimal.hasScale) {
734+           scale =  math.max(scale, lit.getDecimal.getScale)
735+         }
736+         DecimalType (math.max(precision, scale), scale)
737+       case  proto.Expression .Literal .LiteralTypeCase .STRING  => 
738+         StringType 
739+       case  proto.Expression .Literal .LiteralTypeCase .DATE  => 
740+         DateType 
741+       case  proto.Expression .Literal .LiteralTypeCase .TIMESTAMP  => 
742+         TimestampType 
743+       case  proto.Expression .Literal .LiteralTypeCase .TIMESTAMP_NTZ  => 
744+         TimestampNTZType 
745+       case  proto.Expression .Literal .LiteralTypeCase .CALENDAR_INTERVAL  => 
746+         CalendarIntervalType 
747+       case  proto.Expression .Literal .LiteralTypeCase .YEAR_MONTH_INTERVAL  => 
748+         YearMonthIntervalType ()
749+       case  proto.Expression .Literal .LiteralTypeCase .DAY_TIME_INTERVAL  => 
750+         DayTimeIntervalType ()
751+       case  proto.Expression .Literal .LiteralTypeCase .TIME  => 
752+         var  precision  =  TimeType .DEFAULT_PRECISION 
753+         if  (lit.getTime.hasPrecision) {
754+           precision =  lit.getTime.getPrecision
755+         }
756+         TimeType (precision)
757+       case  proto.Expression .Literal .LiteralTypeCase .ARRAY  => 
758+         val  arrayData  =  LiteralValueProtoConverter .toCatalystArray(lit.getArray)
759+         DataTypeProtoConverter .toCatalystType(
760+           proto.DataType .newBuilder
761+             .setArray(LiteralValueProtoConverter .getProtoArrayType(lit.getArray))
762+             .build())
763+       case  proto.Expression .Literal .LiteralTypeCase .MAP  => 
764+         LiteralValueProtoConverter .toCatalystMap(lit.getMap)
765+         DataTypeProtoConverter .toCatalystType(
766+           proto.DataType .newBuilder
767+             .setMap(LiteralValueProtoConverter .getProtoMapType(lit.getMap))
768+             .build())
769+       case  proto.Expression .Literal .LiteralTypeCase .STRUCT  => 
770+         LiteralValueProtoConverter .toCatalystStruct(lit.getStruct)
771+         DataTypeProtoConverter .toCatalystType(
772+           proto.DataType .newBuilder
773+             .setStruct(LiteralValueProtoConverter .getProtoStructType(lit.getStruct))
774+             .build())
775+       case  _ => 
776+         throw  InvalidPlanInput (
777+           s " Unsupported Literal Type:  ${lit.getLiteralTypeCase.name}"  + 
778+             s " ( ${lit.getLiteralTypeCase.getNumber}) " )
779+     }
780+   }
709781}
0 commit comments