Commit 6fb749d

[SPARK-53490] Fix Protobuf conversion in observed metrics
1 parent a74d50b commit 6fb749d

8 files changed: +236 −118 lines

sql/api/src/main/scala/org/apache/spark/sql/Observation.scala

Lines changed: 19 additions & 6 deletions

@@ -58,12 +58,12 @@ class Observation(val name: String) {
 
   private val isRegistered = new AtomicBoolean()
 
-  private val promise = Promise[Map[String, Any]]()
+  private val promise = Promise[Row]()
 
   /**
    * Future holding the (yet to be completed) observation.
    */
-  val future: Future[Map[String, Any]] = promise.future
+  val future: Future[Row] = promise.future
 
   /**
   * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -76,7 +76,10 @@ class Observation(val name: String) {
   * interrupted while waiting
   */
  @throws[InterruptedException]
-  def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf)
+  def get: Map[String, Any] = {
+    val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
+    row.getValuesMap(row.schema.map(_.name))
+  }
 
  /**
   * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -99,7 +102,8 @@ class Observation(val name: String) {
   */
  @throws[InterruptedException]
  private[sql] def getOrEmpty: Map[String, Any] = {
-    Try(SparkThreadUtils.awaitResult(future, 100.millis)).getOrElse(Map.empty)
+    val row = getRowOrEmpty.getOrElse(Row.empty)
+    row.getValuesMap(row.schema.map(_.name))
  }
 
  /**
@@ -118,8 +122,17 @@ class Observation(val name: String) {
   * `true` if all waiting threads were notified, `false` if otherwise.
   */
  private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
-    val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
-    promise.trySuccess(metricsMap)
+    promise.trySuccess(metrics)
+  }
+
+  /**
+   * Get the observed metrics as a Row.
+   *
+   * @return
+   *   the observed metrics as a `Row`, or None if the metrics are not available.
+   */
+  private[sql] def getRowOrEmpty: Option[Row] = {
+    Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
  }
 }
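
With this change the promise completes with the raw metrics Row, and the Map[String, Any] view is derived on demand via getValuesMap. A minimal usage sketch of the public API (standard Dataset.observe usage; the column names below are illustrative, not taken from the commit):

    import org.apache.spark.sql.{Observation, SparkSession}
    import org.apache.spark.sql.functions._

    val spark = SparkSession.builder().getOrCreate()

    // `get` still returns Map[String, Any]; it is now derived from the observed Row,
    // so nested metric types (struct/array/map) keep their schema along the way.
    val observation = Observation("stats")
    spark.range(10)
      .observe(observation, count(lit(1)).as("rows"), max("id").as("maxid"))
      .collect()

    val metrics: Map[String, Any] = observation.get
    assert(metrics("rows") == 10L && metrics("maxid") == 9L)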

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala

Lines changed: 36 additions & 0 deletions

@@ -1716,6 +1716,42 @@ class ClientE2ETestSuite
        schema.fields.head.dataType.asInstanceOf[MapType].valueContainsNull === valueContainsNull)
    }
  }
+
+  test("SPARK-53490: struct type in observed metrics") {
+    val observation = Observation("struct")
+    spark
+      .range(10)
+      .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
+      .collect()
+    val expectedSchema =
+      StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
+    val expectedValue = new GenericRowWithSchema(Array(10, 9), expectedSchema)
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("struct"))
+    assert(observation.get("struct") === expectedValue)
+  }
+
+  test("SPARK-53490: array type in observed metrics") {
+    val observation = Observation("array")
+    spark
+      .range(10)
+      .observe(observation, array(count(lit(1))).as("array"))
+      .collect()
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("array"))
+    assert(observation.get("array") === Array(10))
+  }
+
+  test("SPARK-53490: map type in observed metrics") {
+    val observation = Observation("map")
+    spark
+      .range(10)
+      .observe(observation, map(lit("count"), count(lit(1))).as("map"))
+      .collect()
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("map"))
+    assert(observation.get("map") === Map("count" -> 10))
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala

Lines changed: 1 addition & 1 deletion

@@ -383,7 +383,7 @@ private[sql] object SparkResult {
     (0 until metric.getKeysCount).foreach { i =>
       val key = metric.getKeys(i)
       val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
-      schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
+      schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
       values += value
     }
     new GenericRowWithSchema(values.result(), schema)

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala

Lines changed: 1 addition & 1 deletion

@@ -109,7 +109,7 @@ object DataTypeProtoConverter {
     ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
   }
 
-  private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
+  private[common] def toCatalystStructType(t: proto.DataType.Struct): StructType = {
     val fields = t.getFieldsList.asScala.toSeq.map { protoField =>
       val metadata = if (protoField.hasMetadata) {
         Metadata.fromJson(protoField.getMetadata)

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 126 additions & 52 deletions

@@ -30,12 +30,13 @@ import scala.util.Try
 import com.google.protobuf.ByteString
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.SparkClassUtils
 
 object LiteralValueProtoConverter {
 
@@ -182,47 +183,52 @@
     val sb = builder.getStructBuilder
     val fields = structType.fields
 
-    scalaValue match {
+    val iter = scalaValue match {
       case p: Product =>
-        val iter = p.productIterator
-        var idx = 0
-        if (options.useDeprecatedDataTypeFields) {
-          while (idx < structType.size) {
-            val field = fields(idx)
-            val literalProto =
-              toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
-            sb.addElements(literalProto)
-            idx += 1
-          }
-          sb.setStructType(toConnectProtoType(structType))
-        } else {
-          val dataTypeStruct = proto.DataType.Struct.newBuilder()
-          while (idx < structType.size) {
-            val field = fields(idx)
-            val literalProto =
-              toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
-            sb.addElements(literalProto)
-
-            val fieldBuilder = dataTypeStruct
-              .addFieldsBuilder()
-              .setName(field.name)
-              .setNullable(field.nullable)
-
-            if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
-              fieldBuilder.setDataType(toConnectProtoType(field.dataType))
-            }
+        p.productIterator
+      case r: Row =>
+        r.toSeq.iterator
+      case other =>
+        throw new IllegalArgumentException(
+          s"literal ${other.getClass.getName}($other) not supported (yet)." +
+            s" ${structType.catalogString}")
+    }
 
-            // Set metadata if available
-            if (field.metadata != Metadata.empty) {
-              fieldBuilder.setMetadata(field.metadata.json)
-            }
+    var idx = 0
+    if (options.useDeprecatedDataTypeFields) {
+      while (idx < structType.size) {
+        val field = fields(idx)
+        val literalProto =
+          toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+        sb.addElements(literalProto)
+        idx += 1
+      }
+      sb.setStructType(toConnectProtoType(structType))
+    } else {
+      val dataTypeStruct = proto.DataType.Struct.newBuilder()
+      while (idx < structType.size) {
+        val field = fields(idx)
+        val literalProto =
+          toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+        sb.addElements(literalProto)
+
+        val fieldBuilder = dataTypeStruct
+          .addFieldsBuilder()
+          .setName(field.name)
+          .setNullable(field.nullable)
+
+        if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
+          fieldBuilder.setDataType(toConnectProtoType(field.dataType))
+        }
 
-            idx += 1
-          }
-          sb.setDataTypeStruct(dataTypeStruct.build())
+        // Set metadata if available
+        if (field.metadata != Metadata.empty) {
+          fieldBuilder.setMetadata(field.metadata.json)
         }
-      case other =>
-        throw new IllegalArgumentException(s"literal $other not supported (yet).")
+
+        idx += 1
+      }
+      sb.setDataTypeStruct(dataTypeStruct.build())
     }
 
     sb
@@ -414,6 +420,9 @@
       case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
         toCatalystArray(literal.getArray)
 
+      case proto.Expression.Literal.LiteralTypeCase.MAP =>
+        toCatalystMap(literal.getMap)
+
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         toCatalystStruct(literal.getStruct)
 
@@ -668,23 +677,12 @@
   private def toCatalystStructInternal(
       struct: proto.Expression.Literal.Struct,
       structType: proto.DataType.Struct): Any = {
-    def toTuple[A <: Object](data: Seq[A]): Product = {
-      try {
-        val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
-        tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
-      } catch {
-        case _: Exception =>
-          throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
-      }
-    }
-
-    val size = struct.getElementsCount
-    val structData = Seq.tabulate(size) { i =>
+    val structData = Array.tabulate(struct.getElementsCount) { i =>
       val element = struct.getElements(i)
       val dataType = structType.getFields(i).getDataType
-      getConverter(dataType)(element).asInstanceOf[Object]
+      getConverter(dataType)(element)
     }
-    toTuple(structData)
+    new GenericRowWithSchema(structData, DataTypeProtoConverter.toCatalystStructType(structType))
   }
 
   def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
@@ -706,4 +704,80 @@
   def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
     toCatalystStructInternal(struct, getProtoStructType(struct))
   }
+
+  def getDataType(lit: proto.Expression.Literal): DataType = {
+    lit.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.NULL =>
+        DataTypeProtoConverter.toCatalystType(lit.getNull)
+      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+        BinaryType
+      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+        BooleanType
+      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+        ByteType
+      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+        ShortType
+      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+        IntegerType
+      case proto.Expression.Literal.LiteralTypeCase.LONG =>
+        LongType
+      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+        FloatType
+      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+        DoubleType
+      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+        val decimal = Decimal.apply(lit.getDecimal.getValue)
+        var precision = decimal.precision
+        if (lit.getDecimal.hasPrecision) {
+          precision = math.max(precision, lit.getDecimal.getPrecision)
+        }
+        var scale = decimal.scale
+        if (lit.getDecimal.hasScale) {
+          scale = math.max(scale, lit.getDecimal.getScale)
+        }
+        DecimalType(math.max(precision, scale), scale)
+      case proto.Expression.Literal.LiteralTypeCase.STRING =>
+        StringType
+      case proto.Expression.Literal.LiteralTypeCase.DATE =>
+        DateType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+        TimestampType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+        TimestampNTZType
+      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+        CalendarIntervalType
+      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+        YearMonthIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+        DayTimeIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.TIME =>
+        var precision = TimeType.DEFAULT_PRECISION
+        if (lit.getTime.hasPrecision) {
+          precision = lit.getTime.getPrecision
+        }
+        TimeType(precision)
+      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+        val arrayData = LiteralValueProtoConverter.toCatalystArray(lit.getArray)
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.MAP =>
+        LiteralValueProtoConverter.toCatalystMap(lit.getMap)
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
            .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
+            .build())
+      case _ =>
+        throw InvalidPlanInput(
+          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
+            s"(${lit.getLiteralTypeCase.getNumber})")
+    }
+  }
 }
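
The new getDataType derives the Catalyst DataType from the proto literal itself rather than from the JVM class of the already-converted value, which is what lets struct, array, and map metrics keep their schema on the client. A rough illustration of the intended behavior (the builder call below is a sketch using the generated proto API, not code from the commit):

    import org.apache.spark.connect.proto
    import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
    import org.apache.spark.sql.types.LongType

    // A plain LONG literal maps directly to LongType.
    val longLit = proto.Expression.Literal.newBuilder().setLong(9L).build()
    assert(LiteralValueProtoConverter.getDataType(longLit) == LongType)

    // For ARRAY, MAP and STRUCT literals, getDataType rebuilds the element, key/value
    // and field types from the type information embedded in the literal, instead of
    // guessing from value.getClass as SparkResult did before this commit.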

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala

Lines changed: 8 additions & 6 deletions

@@ -31,6 +31,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.InvalidInputErrors
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
 
 /**
@@ -227,21 +228,22 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
        executeHolder.request.getPlan.getDescriptorForType)
    }
 
-    val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
+    val observedMetrics: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
      executeHolder.observations.map { case (name, observation) =>
-        val values = observation.getOrEmpty.map { case (key, value) =>
-          (Some(key), value)
-        }.toSeq
+        val values =
+          observation.getRowOrEmpty
+            .map(SparkConnectPlanExecution.toObservedMetricsValues(_))
+            .getOrElse(Seq.empty)
        name -> values
      }.toMap
    }
-    val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
+    val accumulatedInPython: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
      executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
        accumulator.synchronized {
          val value = accumulator.value.asScala.toSeq
          if (value.nonEmpty) {
            accumulator.reset()
-            Some("__python_accumulator__" -> value.map(value => (None, value)))
+            Some("__python_accumulator__" -> value.map(value => (None, value, None)))
          } else {
            None
          }
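
The server now forwards an optional DataType with every observed-metric value. The flattening itself is delegated to SparkConnectPlanExecution.toObservedMetricsValues, which belongs to this commit but is not shown on this page; a hypothetical sketch, assuming it simply pairs each column of the observed Row with its name and schema type:

    // Hypothetical sketch only; the real helper lives in SparkConnectPlanExecution
    // (one of the 8 changed files) and may differ in detail.
    def toObservedMetricsValues(row: Row): Seq[(Option[String], Any, Option[DataType])] = {
      row.schema.fields.toSeq.zipWithIndex.map { case (field, i) =>
        (Some(field.name), row.get(i), Some(field.dataType))
      }
    }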
