25 changes: 19 additions & 6 deletions sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -58,12 +58,12 @@ class Observation(val name: String) {

private val isRegistered = new AtomicBoolean()

private val promise = Promise[Map[String, Any]]()
private val promise = Promise[Row]()

/**
* Future holding the (yet to be completed) observation.
*/
val future: Future[Map[String, Any]] = promise.future
val future: Future[Row] = promise.future

/**
* (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -76,7 +76,10 @@ class Observation(val name: String) {
* interrupted while waiting
*/
@throws[InterruptedException]
def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf)
def get: Map[String, Any] = {
val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
row.getValuesMap(row.schema.map(_.name))
}

/**
* (Java-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -99,7 +102,8 @@ */
*/
@throws[InterruptedException]
private[sql] def getOrEmpty: Map[String, Any] = {
Try(SparkThreadUtils.awaitResult(future, 100.millis)).getOrElse(Map.empty)
val row = getRowOrEmpty.getOrElse(Row.empty)
row.getValuesMap(row.schema.map(_.name))
}

/**
@@ -118,8 +122,17 @@
* `true` if all waiting threads were notified, `false` otherwise.
*/
private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
promise.trySuccess(metricsMap)
promise.trySuccess(metrics)
}

/**
* Get the observed metrics as a Row.
*
* @return
* the observed metrics as a `Row`, or None if the metrics are not available.
*/
private[sql] def getRowOrEmpty: Option[Row] = {
Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
}
}
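
The promise now completes with the raw metrics `Row` instead of a pre-built `Map`; `get` and `getOrEmpty` rebuild the map view via `Row.getValuesMap`, so existing callers are unaffected. A minimal usage sketch (not part of this diff; the column names and the local session are illustrative):

```scala
import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[1]").getOrCreate()
val observation = Observation("stats")

spark.range(100)
  .observe(observation, count(lit(1)).as("rows"), max("id").as("maxid"))
  .collect()

// `future` now holds a Row, but `get` still returns Map[String, Any],
// rebuilt from that Row via getValuesMap.
val metrics: Map[String, Any] = observation.get
assert(metrics("rows") == 100L && metrics("maxid") == 99L)
```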

@@ -1749,6 +1749,42 @@ class ClientE2ETestSuite
val nullRows = nullResult.filter(_.getAs[Long]("id") >= 5)
assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0))
}

test("SPARK-53490: struct type in observed metrics") {
val observation = Observation("struct")
spark
.range(10)
.observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
.collect()
val expectedSchema =
StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
val expectedValue = new GenericRowWithSchema(Array(10, 9), expectedSchema)
assert(observation.get.size === 1)
assert(observation.get.contains("struct"))
assert(observation.get("struct") === expectedValue)
}

test("SPARK-53490: array type in observed metrics") {
val observation = Observation("array")
spark
.range(10)
.observe(observation, array(count(lit(1))).as("array"))
.collect()
assert(observation.get.size === 1)
assert(observation.get.contains("array"))
assert(observation.get("array") === Array(10))
}

test("SPARK-53490: map type in observed metrics") {
val observation = Observation("map")
spark
.range(10)
.observe(observation, map(lit("count"), count(lit(1))).as("map"))
.collect()
assert(observation.get.size === 1)
assert(observation.get.contains("map"))
assert(observation.get("map") === Map("count" -> 10))
}
}

private[sql] case class ClassData(a: String, b: Int)
@@ -383,7 +383,7 @@ private[sql] object SparkResult {
(0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
values += value
}
new GenericRowWithSchema(values.result(), schema)
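
The column type for each metric is now read from the literal proto itself via `getDataType`, rather than inferred from the runtime class of the converted Scala value, which cannot work for nested values. A short, self-contained illustration of that limitation (an assumption about the motivation, not code from this PR):

```scala
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

val structSchema = StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
val metricValue: Any = new GenericRowWithSchema(Array(10L, 9L), structSchema)

// The runtime class alone only says "GenericRowWithSchema"; the field names and
// types (rows LONG, maxid LONG) are not recoverable from it, so the proto
// literal has to carry the data type instead.
println(metricValue.getClass.getSimpleName) // GenericRowWithSchema
```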
@@ -109,7 +109,7 @@ object DataTypeProtoConverter {
ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
}

private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
private[common] def toCatalystStructType(t: proto.DataType.Struct): StructType = {
val fields = t.getFieldsList.asScala.toSeq.map { protoField =>
val metadata = if (protoField.hasMetadata) {
Metadata.fromJson(protoField.getMetadata)
@@ -30,12 +30,13 @@ import scala.util.Try
import com.google.protobuf.ByteString

import org.apache.spark.connect.proto
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.SparkClassUtils

object LiteralValueProtoConverter {

@@ -223,52 +224,51 @@ object LiteralValueProtoConverter {
val sb = builder.getStructBuilder
val fields = structType.fields

scalaValue match {
val iter = scalaValue match {
case p: Product =>
val iter = p.productIterator
var idx = 0
if (options.useDeprecatedDataTypeFields) {
while (idx < structType.size) {
val field = fields(idx)
// For backward compatibility, we need the data type for each field.
val literalProto = toLiteralProtoBuilderInternal(
iter.next(),
field.dataType,
options,
needDataType = true).build()
sb.addElements(literalProto)
idx += 1
}
sb.setStructType(toConnectProtoType(structType))
} else {
while (idx < structType.size) {
val field = fields(idx)
val literalProto =
toLiteralProtoBuilderInternal(iter.next(), field.dataType, options, needDataType)
.build()
sb.addElements(literalProto)

if (needDataType) {
val fieldBuilder = sb.getDataTypeStructBuilder
.addFieldsBuilder()
.setName(field.name)
.setNullable(field.nullable)

if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
fieldBuilder.setDataType(toConnectProtoType(field.dataType))
}

// Set metadata if available
if (field.metadata != Metadata.empty) {
fieldBuilder.setMetadata(field.metadata.json)
}
}
p.productIterator
case r: Row =>
r.toSeq.iterator
case other =>
throw new IllegalArgumentException(
s"literal ${other.getClass.getName}($other) not supported (yet).")
}

idx += 1
}
var idx = 0
if (options.useDeprecatedDataTypeFields) {
while (idx < structType.size) {
val field = fields(idx)
val literalProto =
toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
sb.addElements(literalProto)
idx += 1
}
sb.setStructType(toConnectProtoType(structType))
} else {
val dataTypeStruct = proto.DataType.Struct.newBuilder()
while (idx < structType.size) {
val field = fields(idx)
val literalProto =
toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
sb.addElements(literalProto)

val fieldBuilder = dataTypeStruct
.addFieldsBuilder()
.setName(field.name)
.setNullable(field.nullable)

if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
fieldBuilder.setDataType(toConnectProtoType(field.dataType))
}
case other =>
throw new IllegalArgumentException(s"literal $other not supported (yet).")

// Set metadata if available
if (field.metadata != Metadata.empty) {
fieldBuilder.setMetadata(field.metadata.json)
}

idx += 1
}
sb.setDataTypeStruct(dataTypeStruct.build())
}

sb
@@ -721,23 +721,12 @@ object LiteralValueProtoConverter {
private def toScalaStructInternal(
struct: proto.Expression.Literal.Struct,
structType: proto.DataType.Struct): Any = {
def toTuple[A <: Object](data: Seq[A]): Product = {
try {
val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
} catch {
case _: Exception =>
throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
}
}

val size = struct.getElementsCount
val structData = Seq.tabulate(size) { i =>
val structData = Array.tabulate(struct.getElementsCount) { i =>
val element = struct.getElements(i)
val dataType = structType.getFields(i).getDataType
getConverter(dataType)(element).asInstanceOf[Object]
getConverter(dataType)(element)
}
toTuple(structData)
new GenericRowWithSchema(structData, DataTypeProtoConverter.toCatalystStructType(structType))
}

def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
@@ -759,4 +748,77 @@
def toScalaStruct(struct: proto.Expression.Literal.Struct): Any = {
toScalaStructInternal(struct, getProtoStructType(struct))
}

def getDataType(lit: proto.Expression.Literal): DataType = {
lit.getLiteralTypeCase match {
case proto.Expression.Literal.LiteralTypeCase.NULL =>
DataTypeProtoConverter.toCatalystType(lit.getNull)
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
BinaryType
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
BooleanType
case proto.Expression.Literal.LiteralTypeCase.BYTE =>
ByteType
case proto.Expression.Literal.LiteralTypeCase.SHORT =>
ShortType
case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
IntegerType
case proto.Expression.Literal.LiteralTypeCase.LONG =>
LongType
case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
FloatType
case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
DoubleType
case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
val decimal = Decimal.apply(lit.getDecimal.getValue)
var precision = decimal.precision
if (lit.getDecimal.hasPrecision) {
precision = math.max(precision, lit.getDecimal.getPrecision)
}
var scale = decimal.scale
if (lit.getDecimal.hasScale) {
scale = math.max(scale, lit.getDecimal.getScale)
}
DecimalType(math.max(precision, scale), scale)
case proto.Expression.Literal.LiteralTypeCase.STRING =>
StringType
case proto.Expression.Literal.LiteralTypeCase.DATE =>
DateType
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
TimestampType
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
TimestampNTZType
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
CalendarIntervalType
case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
YearMonthIntervalType()
case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
DayTimeIntervalType()
case proto.Expression.Literal.LiteralTypeCase.TIME =>
var precision = TimeType.DEFAULT_PRECISION
if (lit.getTime.hasPrecision) {
precision = lit.getTime.getPrecision
}
TimeType(precision)
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
DataTypeProtoConverter.toCatalystType(
proto.DataType.newBuilder
.setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
.build())
case proto.Expression.Literal.LiteralTypeCase.MAP =>
DataTypeProtoConverter.toCatalystType(
proto.DataType.newBuilder
.setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
.build())
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
DataTypeProtoConverter.toCatalystType(
proto.DataType.newBuilder
.setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
.build())
case _ =>
throw InvalidPlanInput(
s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
s"(${lit.getLiteralTypeCase.getNumber})")
}
}
}
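
With these changes a struct literal round-trips as a schema-carrying `Row` rather than a `scala.TupleN`, and its Catalyst type can be read straight from the proto. A hedged round-trip sketch (it assumes the existing `toLiteralProto(value, dataType)` and `toScalaValue` entry points keep their signatures; it is not code from this PR):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
import org.apache.spark.sql.types._

val structType = StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))

// Rows are now accepted on the encoding side (see the new `case r: Row` above).
val litProto = LiteralValueProtoConverter.toLiteralProto(Row(10L, 9L), structType)

// Decoding yields a GenericRowWithSchema instead of a Tuple2, and getDataType
// recovers the StructType without materializing the value first.
val decoded = LiteralValueProtoConverter.toScalaValue(litProto)
val decodedType = LiteralValueProtoConverter.getDataType(litProto)
```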
@@ -31,6 +31,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.planner.InvalidInputErrors
import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

/**
@@ -227,21 +228,22 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
executeHolder.request.getPlan.getDescriptorForType)
}

val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
val observedMetrics: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
executeHolder.observations.map { case (name, observation) =>
val values = observation.getOrEmpty.map { case (key, value) =>
(Some(key), value)
}.toSeq
val values =
observation.getRowOrEmpty
.map(SparkConnectPlanExecution.toObservedMetricsValues(_))
.getOrElse(Seq.empty)
name -> values
}.toMap
}
val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
val accumulatedInPython: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
accumulator.synchronized {
val value = accumulator.value.asScala.toSeq
if (value.nonEmpty) {
accumulator.reset()
Some("__python_accumulator__" -> value.map(value => (None, value)))
Some("__python_accumulator__" -> value.map(value => (None, value, None)))
} else {
None
}
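
Each observed metric value is now reported as a `(key, value, dataType)` triple so the client can rebuild the schema instead of guessing it from the value's runtime class. An illustrative sketch of the shape produced for a simple observation row (`SparkConnectPlanExecution.toObservedMetricsValues` is defined elsewhere in this change and may differ in detail):

```scala
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical contents for an observation with columns rows=10L and maxid=9L.
val values: Seq[(Option[String], Any, Option[DataType])] =
  Seq((Some("rows"), 10L, Some(LongType)), (Some("maxid"), 9L, Some(LongType)))

// Python accumulator values carry neither a key nor a data type:
val pythonValues: Seq[(Option[String], Any, Option[DataType])] =
  Seq((None, "accumulated", None))
```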