
Commit 6dd4001

heyihong authored and zhengruifeng committed
[SPARK-53490][CONNECT][SQL] Fix Protobuf conversion in observed metrics
### What changes were proposed in this pull request?

This PR fixes a critical issue in the Protobuf conversion of observed metrics in Spark Connect, specifically when dealing with complex data types such as structs, arrays, and maps. The main changes are:

1. **Modified the `Observation` class to store `Row` objects instead of `Map[String, Any]`**: changed the internal promise type from `Promise[Map[String, Any]]` to `Promise[Row]` to preserve type information during Protobuf serialization/deserialization.
2. **Enhanced Protobuf conversion for complex types**:
   - Added proper handling for struct types by creating `GenericRowWithSchema` objects instead of tuples
   - Added support for map type conversion in `LiteralValueProtoConverter`
   - Improved data type inference with a new `getDataType` method that properly handles all literal types
3. **Fixed observed metrics**: modified the observed-metrics processing to include data type information in the Protobuf conversion, ensuring that complex types are properly serialized and deserialized.

### Why are the changes needed?

The previous implementation had several issues:

1. **Data type loss**: observed metrics lost their original data types during Protobuf conversion, causing errors.
2. **Struct handling problems**: the conversion logic did not properly handle `Row` objects and struct types.

### Does this PR introduce _any_ user-facing change?

Yes. This PR fixes a bug that prevented users from successfully using observed metrics with complex data types (structs, arrays, maps) in Spark Connect. Users can now:

- Use `struct()` expressions in observed metrics and receive properly typed `GenericRowWithSchema` objects
- Use `array()` expressions in observed metrics and receive properly typed arrays
- Use `map()` expressions in observed metrics and receive properly typed maps

Previously, the code below would fail:

```scala
val observation = Observation("struct")
spark
  .range(10)
  .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
  .collect()
observation.get

// Below is the error message:
"""
org.apache.spark.SparkUnsupportedOperationException: literal [10,9] not supported (yet).
org.apache.spark.sql.connect.common.LiteralValueProtoConverter$.toLiteralProtoBuilder(LiteralValueProtoConverter.scala:104)
org.apache.spark.sql.connect.common.LiteralValueProtoConverter$.toLiteralProto(LiteralValueProtoConverter.scala:203)
org.apache.spark.sql.connect.execution.SparkConnectPlanExecution$.$anonfun$createObservedMetricsResponse$2(SparkConnectPlanExecution.scala:571)
org.apache.spark.sql.connect.execution.SparkConnectPlanExecution$.$anonfun$createObservedMetricsResponse$2$adapted(SparkConnectPlanExecution.scala:570)
"""
```

### How was this patch tested?

- `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite -- -z SPARK-53490"`
- `build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.5.9

Closes #52236 from heyihong/SPARK-53490.

Authored-by: Yihong He <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
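With this fix, the same observation completes and returns a typed value. A minimal sketch of the new behavior (assuming a running `spark` session), mirroring the `ClientE2ETestSuite` tests added in this commit:

```scala
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions.{count, lit, max, struct}

val observation = Observation("struct")
spark
  .range(10)
  .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
  .collect()

// The struct metric now comes back as a GenericRowWithSchema with fields
// `rows` and `maxid` instead of failing during protobuf conversion.
val metric = observation.get("struct") // Row(10, 9) with schema struct<rows: bigint, maxid: bigint>
```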
1 parent feaf659 commit 6dd4001

9 files changed (+245, -137 lines)

sql/api/src/main/scala/org/apache/spark/sql/Observation.scala

Lines changed: 19 additions & 6 deletions
```diff
@@ -58,12 +58,12 @@ class Observation(val name: String) {

   private val isRegistered = new AtomicBoolean()

-  private val promise = Promise[Map[String, Any]]()
+  private val promise = Promise[Row]()

   /**
    * Future holding the (yet to be completed) observation.
    */
-  val future: Future[Map[String, Any]] = promise.future
+  val future: Future[Row] = promise.future

   /**
    * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -76,7 +76,10 @@ class Observation(val name: String) {
    * interrupted while waiting
    */
   @throws[InterruptedException]
-  def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf)
+  def get: Map[String, Any] = {
+    val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
+    row.getValuesMap(row.schema.map(_.name))
+  }

   /**
    * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -99,7 +102,8 @@ class Observation(val name: String) {
    */
   @throws[InterruptedException]
   private[sql] def getOrEmpty: Map[String, Any] = {
-    Try(SparkThreadUtils.awaitResult(future, 100.millis)).getOrElse(Map.empty)
+    val row = getRowOrEmpty.getOrElse(Row.empty)
+    row.getValuesMap(row.schema.map(_.name))
   }

   /**
@@ -118,8 +122,17 @@ class Observation(val name: String) {
    * `true` if all waiting threads were notified, `false` if otherwise.
    */
   private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
-    val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
-    promise.trySuccess(metricsMap)
+    promise.trySuccess(metrics)
+  }
+
+  /**
+   * Get the observed metrics as a Row.
+   *
+   * @return
+   *   the observed metrics as a `Row`, or None if the metrics are not available.
+   */
+  private[sql] def getRowOrEmpty: Option[Row] = {
+    Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
   }
 }
```
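The promise now completes with the raw `Row`; `get` and `getOrEmpty` derive the `Map[String, Any]` view from it via `Row.getValuesMap`. A small standalone sketch of that derivation (the field names and values are illustrative):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// A metrics Row with a schema, as the promise now stores it.
val schema = StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
val row: Row = new GenericRowWithSchema(Array(10L, 9L), schema)

// What Observation.get now does to expose the old Map-based API.
val asMap: Map[String, Any] = row.getValuesMap(row.schema.map(_.name))
// asMap == Map("rows" -> 10L, "maxid" -> 9L)
```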

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala

Lines changed: 36 additions & 0 deletions
```diff
@@ -1749,6 +1749,42 @@ class ClientE2ETestSuite
     val nullRows = nullResult.filter(_.getAs[Long]("id") >= 5)
     assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0))
   }
+
+  test("SPARK-53490: struct type in observed metrics") {
+    val observation = Observation("struct")
+    spark
+      .range(10)
+      .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
+      .collect()
+    val expectedSchema =
+      StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
+    val expectedValue = new GenericRowWithSchema(Array(10, 9), expectedSchema)
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("struct"))
+    assert(observation.get("struct") === expectedValue)
+  }
+
+  test("SPARK-53490: array type in observed metrics") {
+    val observation = Observation("array")
+    spark
+      .range(10)
+      .observe(observation, array(count(lit(1))).as("array"))
+      .collect()
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("array"))
+    assert(observation.get("array") === Array(10))
+  }
+
+  test("SPARK-53490: map type in observed metrics") {
+    val observation = Observation("map")
+    spark
+      .range(10)
+      .observe(observation, map(lit("count"), count(lit(1))).as("map"))
+      .collect()
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("map"))
+    assert(observation.get("map") === Map("count" -> 10))
+  }
 }

 private[sql] case class ClassData(a: String, b: Int)
```

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -383,7 +383,7 @@ private[sql] object SparkResult {
     (0 until metric.getKeysCount).foreach { i =>
       val key = metric.getKeys(i)
       val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
-      schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
+      schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
       values += value
     }
     new GenericRowWithSchema(values.result(), schema)
```
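Previously the field type was inferred from the JVM class of the already-converted value, which cannot describe nested types; the new `getDataType` call reads the type straight off the literal proto. A hedged sketch of what the loop above builds for a single LONG metric (the key name and the proto builder calls are assumptions for illustration):

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
import org.apache.spark.sql.types.StructType

// One observed metric arriving as a literal proto, keyed "maxid" for illustration.
val lit = proto.Expression.Literal.newBuilder().setLong(9L).build()

var schema = new StructType()
schema = schema.add("maxid", LiteralValueProtoConverter.getDataType(lit)) // bigint
val value = LiteralValueProtoConverter.toScalaValue(lit)                  // 9L
```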

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -109,7 +109,7 @@ object DataTypeProtoConverter {
     ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
   }

-  private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
+  private[common] def toCatalystStructType(t: proto.DataType.Struct): StructType = {
     val fields = t.getFieldsList.asScala.toSeq.map { protoField =>
       val metadata = if (protoField.hasMetadata) {
         Metadata.fromJson(protoField.getMetadata)
```

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 120 additions & 58 deletions
```diff
@@ -30,12 +30,13 @@ import scala.util.Try
 import com.google.protobuf.ByteString

 import org.apache.spark.connect.proto
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.SparkClassUtils

 object LiteralValueProtoConverter {

@@ -223,52 +224,51 @@ object LiteralValueProtoConverter {
     val sb = builder.getStructBuilder
     val fields = structType.fields

-    scalaValue match {
+    val iter = scalaValue match {
       case p: Product =>
-        val iter = p.productIterator
-        var idx = 0
-        if (options.useDeprecatedDataTypeFields) {
-          while (idx < structType.size) {
-            val field = fields(idx)
-            // For backward compatibility, we need the data type for each field.
-            val literalProto = toLiteralProtoBuilderInternal(
-              iter.next(),
-              field.dataType,
-              options,
-              needDataType = true).build()
-            sb.addElements(literalProto)
-            idx += 1
-          }
-          sb.setStructType(toConnectProtoType(structType))
-        } else {
-          while (idx < structType.size) {
-            val field = fields(idx)
-            val literalProto =
-              toLiteralProtoBuilderInternal(iter.next(), field.dataType, options, needDataType)
-                .build()
-            sb.addElements(literalProto)
-
-            if (needDataType) {
-              val fieldBuilder = sb.getDataTypeStructBuilder
-                .addFieldsBuilder()
-                .setName(field.name)
-                .setNullable(field.nullable)
-
-              if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
-                fieldBuilder.setDataType(toConnectProtoType(field.dataType))
-              }
-
-              // Set metadata if available
-              if (field.metadata != Metadata.empty) {
-                fieldBuilder.setMetadata(field.metadata.json)
-              }
-            }
-
-            idx += 1
-          }
+        p.productIterator
+      case r: Row =>
+        r.toSeq.iterator
+      case other =>
+        throw new IllegalArgumentException(
+          s"literal ${other.getClass.getName}($other) not supported (yet).")
+    }
+
+    var idx = 0
+    if (options.useDeprecatedDataTypeFields) {
+      while (idx < structType.size) {
+        val field = fields(idx)
+        val literalProto =
+          toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+        sb.addElements(literalProto)
+        idx += 1
+      }
+      sb.setStructType(toConnectProtoType(structType))
+    } else {
+      val dataTypeStruct = proto.DataType.Struct.newBuilder()
+      while (idx < structType.size) {
+        val field = fields(idx)
+        val literalProto =
+          toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+        sb.addElements(literalProto)
+
+        val fieldBuilder = dataTypeStruct
+          .addFieldsBuilder()
+          .setName(field.name)
+          .setNullable(field.nullable)
+
+        if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
+          fieldBuilder.setDataType(toConnectProtoType(field.dataType))
         }
-      case other =>
-        throw new IllegalArgumentException(s"literal $other not supported (yet).")
+
+        // Set metadata if available
+        if (field.metadata != Metadata.empty) {
+          fieldBuilder.setMetadata(field.metadata.json)
+        }
+
+        idx += 1
+      }
+      sb.setDataTypeStruct(dataTypeStruct.build())
     }

     sb

@@ -721,23 +721,12 @@ object LiteralValueProtoConverter {
   private def toScalaStructInternal(
       struct: proto.Expression.Literal.Struct,
       structType: proto.DataType.Struct): Any = {
-    def toTuple[A <: Object](data: Seq[A]): Product = {
-      try {
-        val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
-        tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
-      } catch {
-        case _: Exception =>
-          throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
-      }
-    }
-
-    val size = struct.getElementsCount
-    val structData = Seq.tabulate(size) { i =>
+    val structData = Array.tabulate(struct.getElementsCount) { i =>
       val element = struct.getElements(i)
       val dataType = structType.getFields(i).getDataType
-      getConverter(dataType)(element).asInstanceOf[Object]
+      getConverter(dataType)(element)
     }
-    toTuple(structData)
+    new GenericRowWithSchema(structData, DataTypeProtoConverter.toCatalystStructType(structType))
   }

   def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {

@@ -759,4 +748,77 @@ object LiteralValueProtoConverter {
   def toScalaStruct(struct: proto.Expression.Literal.Struct): Any = {
     toScalaStructInternal(struct, getProtoStructType(struct))
   }
+
+  def getDataType(lit: proto.Expression.Literal): DataType = {
+    lit.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.NULL =>
+        DataTypeProtoConverter.toCatalystType(lit.getNull)
+      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+        BinaryType
+      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+        BooleanType
+      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+        ByteType
+      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+        ShortType
+      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+        IntegerType
+      case proto.Expression.Literal.LiteralTypeCase.LONG =>
+        LongType
+      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+        FloatType
+      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+        DoubleType
+      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+        val decimal = Decimal.apply(lit.getDecimal.getValue)
+        var precision = decimal.precision
+        if (lit.getDecimal.hasPrecision) {
+          precision = math.max(precision, lit.getDecimal.getPrecision)
+        }
+        var scale = decimal.scale
+        if (lit.getDecimal.hasScale) {
+          scale = math.max(scale, lit.getDecimal.getScale)
+        }
+        DecimalType(math.max(precision, scale), scale)
+      case proto.Expression.Literal.LiteralTypeCase.STRING =>
+        StringType
+      case proto.Expression.Literal.LiteralTypeCase.DATE =>
+        DateType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+        TimestampType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+        TimestampNTZType
+      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+        CalendarIntervalType
+      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+        YearMonthIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+        DayTimeIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.TIME =>
+        var precision = TimeType.DEFAULT_PRECISION
+        if (lit.getTime.hasPrecision) {
+          precision = lit.getTime.getPrecision
+        }
+        TimeType(precision)
+      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.MAP =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
+            .build())
+      case _ =>
+        throw InvalidPlanInput(
+          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
+            s"(${lit.getLiteralTypeCase.getNumber})")
+    }
+  }
 }
```
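One subtlety in the new `getDataType` is the DECIMAL case: the precision and scale parsed from the literal's string value are widened with the precision and scale declared on the proto. A hedged illustration, assuming the usual Spark Connect `Decimal` literal builder shape:

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
import org.apache.spark.sql.types.DecimalType

// "1.20" parses to precision 3 / scale 2; the declared precision 10 wins,
// so the inferred Catalyst type is DecimalType(10, 2).
val decimalLit = proto.Expression.Literal
  .newBuilder()
  .setDecimal(
    proto.Expression.Literal.Decimal
      .newBuilder()
      .setValue("1.20")
      .setPrecision(10)
      .setScale(2))
  .build()

assert(LiteralValueProtoConverter.getDataType(decimalLit) == DecimalType(10, 2))
```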

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala

Lines changed: 8 additions & 6 deletions
```diff
@@ -31,6 +31,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.InvalidInputErrors
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils

 /**
@@ -227,21 +228,22 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
         executeHolder.request.getPlan.getDescriptorForType)
     }

-    val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
+    val observedMetrics: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
       executeHolder.observations.map { case (name, observation) =>
-        val values = observation.getOrEmpty.map { case (key, value) =>
-          (Some(key), value)
-        }.toSeq
+        val values =
+          observation.getRowOrEmpty
+            .map(SparkConnectPlanExecution.toObservedMetricsValues(_))
+            .getOrElse(Seq.empty)
         name -> values
       }.toMap
     }
-    val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
+    val accumulatedInPython: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
       executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
         accumulator.synchronized {
           val value = accumulator.value.asScala.toSeq
           if (value.nonEmpty) {
             accumulator.reset()
-            Some("__python_accumulator__" -> value.map(value => (None, value)))
+            Some("__python_accumulator__" -> value.map(value => (None, value, None)))
           } else {
             None
           }
```
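Each observed metric value now travels as a `(Option[String], Any, Option[DataType])` triple so the response builder can serialize the value together with its type. `SparkConnectPlanExecution.toObservedMetricsValues` is defined in a server-side file of this commit that is not shown above; the following is only a hypothetical re-implementation to illustrate the intended shape, not the actual code:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.DataType

// Hypothetical sketch: pair every metric field with its name and DataType,
// so downstream protobuf conversion no longer has to guess the type.
def toObservedMetricsValuesSketch(row: Row): Seq[(Option[String], Any, Option[DataType])] =
  row.schema.fields.toIndexedSeq.zipWithIndex.map { case (field, i) =>
    (Some(field.name), row.get(i), Some(field.dataType))
  }
```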
