@@ -22,6 +22,7 @@ import java.lang.{Long => JLong}
import java.util.UUID

import scala.jdk.CollectionConverters._
import scala.math.BigDecimal.RoundingMode
import scala.util.control.NonFatal

import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
@@ -35,7 +36,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Evolving
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.streaming.SafeJsonSerializer.{safeDoubleToJValue, safeMapToJValue}
+import org.apache.spark.sql.streaming.SafeJsonSerializer.{safeDecimalToJValue, safeMapToJValue}
import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS

/**
@@ -183,8 +184,8 @@ class StreamingQueryProgress private[spark] (
("batchId" -> JInt(batchId)) ~
("batchDuration" -> JInt(batchDuration)) ~
("numInputRows" -> JInt(numInputRows)) ~
("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~
("inputRowsPerSecond" -> safeDecimalToJValue(inputRowsPerSecond)) ~
("processedRowsPerSecond" -> safeDecimalToJValue(processedRowsPerSecond)) ~
("durationMs" -> safeMapToJValue[JLong](durationMs, v => JInt(v.toLong))) ~
("eventTime" -> safeMapToJValue[String](eventTime, s => JString(s))) ~
("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~
@@ -255,8 +256,8 @@ class SourceProgress protected[spark] (
("endOffset" -> tryParse(endOffset)) ~
("latestOffset" -> tryParse(latestOffset)) ~
("numInputRows" -> JInt(numInputRows)) ~
("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~
("inputRowsPerSecond" -> safeDecimalToJValue(inputRowsPerSecond)) ~
("processedRowsPerSecond" -> safeDecimalToJValue(processedRowsPerSecond)) ~
("metrics" -> safeMapToJValue[String](metrics, s => JString(s)))
}

@@ -316,6 +317,8 @@ private[sql] object SinkProgress {
}

private object SafeJsonSerializer {

/** Convert Double to JValue while handling NaN or infinite values */
def safeDoubleToJValue(value: Double): JValue = {
if (value.isNaN || value.isInfinity) JNothing else JDouble(value)
}
@@ -326,4 +329,10 @@ private object SafeJsonSerializer {
val keys = map.asScala.keySet.toSeq.sorted
keys.map { k => k -> valueToJValue(map.get(k)): JObject }.reduce(_ ~ _)
}

/** Convert Double to a plain-notation JDecimal while handling NaN or infinite values */
def safeDecimalToJValue(value: Double): JValue = {
if (value.isNaN || value.isInfinity) JNothing
else JDecimal(BigDecimal(value).setScale(1, RoundingMode.HALF_UP))
}
}
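
For context on the fix above: json4s's JDouble ultimately prints via Double.toString, which switches to scientific notation for magnitudes of 1e7 and above, whereas a JDecimal backed by a scaled BigDecimal always renders in plain notation. Below is a minimal standalone sketch of that difference (not part of the patch; it assumes json4s-jackson on the classpath, and PlainNotationDemo is a hypothetical object name):

import org.json4s._
import org.json4s.jackson.JsonMethods._
import scala.math.BigDecimal.RoundingMode

object PlainNotationDemo {
  def main(args: Array[String]): Unit = {
    val rate = 6.923076923076923e8
    // JDouble serializes via Double.toString: prints "6.923076923076923E8"
    println(compact(render(JDouble(rate))))
    // JDecimal over a scaled BigDecimal stays plain: prints "692307692.3"
    println(compact(render(JDecimal(BigDecimal(rate).setScale(1, RoundingMode.HALF_UP)))))
  }
}

The test-suite diff that follows exercises this round trip through StreamingQueryProgress.jsonValue.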
@@ -23,10 +23,12 @@ import java.time.temporal.ChronoUnit
import java.util.UUID

import scala.jdk.CollectionConverters._
import scala.math.BigDecimal.RoundingMode

import org.json4s.jackson.JsonMethods._
import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.matchers.should.Matchers
import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.Row
@@ -40,7 +42,7 @@ import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._

-class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
+class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually with Matchers {
test("StreamingQueryProgress - prettyJson") {
val json1 = testProgress1.prettyJson
assertJson(
@@ -400,6 +402,40 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
assert(data(0).getAs[Timestamp](0).equals(validValue))
}

test("SPARK-53491: inputRowsPerSecond and processedRowsPerSecond " +
"should never be with scientific notation") {
val progress = testProgress4.jsonValue

// Raw rate values, matching testProgress4's source
val inputRowsPerSecond: Double = 6.923076923076923E8
val processedRowsPerSecond: Double = 2.923076923076923E8

// Read the values back from the progress JSON and parse them to Double
// for numeric comparison
val inputRowsPerSecondJSON = (progress \ "inputRowsPerSecond").values.toString
.toDouble
val processedRowsPerSecondJSON = (progress \ "processedRowsPerSecond").values.toString
.toDouble

// Expected values after rounding to one decimal place
val inputRowsPerSecondExpected = BigDecimal(inputRowsPerSecond)
.setScale(1, RoundingMode.HALF_UP).toDouble
val processedRowsPerSecondExpected = BigDecimal(processedRowsPerSecond)
.setScale(1, RoundingMode.HALF_UP).toDouble

// Fails if inputRowsPerSecond is rendered in scientific (E) notation
(progress \ "inputRowsPerSecond").values.toString should not include "E"

// Fails if processedRowsPerSecond is rendered in scientific (E) notation
(progress \ "processedRowsPerSecond").values.toString should not include "E"

// The JSON values should equal the decimal-rounded originals;
// compare with an epsilon since these are floating-point values
val epsilon = 1e-6
inputRowsPerSecondJSON shouldBe inputRowsPerSecondExpected +- epsilon
processedRowsPerSecondJSON shouldBe processedRowsPerSecondExpected +- epsilon
}

def waitUntilBatchProcessed: AssertOnQuery = Execute { q =>
eventually(Timeout(streamingTimeout)) {
if (q.exception.isEmpty) {
@@ -522,6 +558,44 @@ object StreamingQueryStatusAndProgressSuite {
observedMetrics = null
)

val testProgress4 = new StreamingQueryProgress(
id = UUID.randomUUID,
runId = UUID.randomUUID,
name = "myName",
timestamp = "2025-09-05T20:54:20.827Z",
batchId = 2L,
batchDuration = 0L,
durationMs = new java.util.HashMap(Map("total" -> 0L).transform((_, v) => long2Long(v)).asJava),
eventTime = new java.util.HashMap(Map(
"max" -> "2025-09-05T20:54:20.827Z",
"min" -> "2025-09-05T20:54:20.827Z",
"avg" -> "2025-09-05T20:54:20.827Z",
"watermark" -> "2025-09-05T20:54:20.827Z").asJava),
stateOperators = Array(new StateOperatorProgress(operatorName = "op1",
numRowsTotal = 0, numRowsUpdated = 1, allUpdatesTimeMs = 1, numRowsRemoved = 2,
allRemovalsTimeMs = 34, commitTimeMs = 23, memoryUsedBytes = 3, numRowsDroppedByWatermark = 0,
numShufflePartitions = 2, numStateStoreInstances = 2,
customMetrics = new java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L,
"loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L)
.transform((_, v) => long2Long(v)).asJava)
)),
sources = Array(
new SourceProgress(
description = "source",
startOffset = "123",
endOffset = "456",
latestOffset = "789",
numInputRows = 678,
inputRowsPerSecond = 6.923076923076923E8, // large Double whose toString uses scientific notation
processedRowsPerSecond = 2.923076923076923E8
)
),
sink = SinkProgress("sink", None),
observedMetrics = new java.util.HashMap(Map(
"event1" -> row(schema1, 1L, 3.0d),
"event2" -> row(schema2, 1L, "hello", "world")).asJava)
)

val testStatus = new StreamingQueryStatus("active", true, false)
}
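
A note on the rounding choice used throughout: setScale(1, RoundingMode.HALF_UP) keeps one fractional digit and rounds ties away from zero. A small sketch of those semantics (illustrative values only, not from the patch; HalfUpDemo is a hypothetical object name):

import scala.math.BigDecimal.RoundingMode

object HalfUpDemo {
  def main(args: Array[String]): Unit = {
    // a trailing 5 at the target scale rounds away from zero
    println(BigDecimal("2.25").setScale(1, RoundingMode.HALF_UP))  // 2.3
    println(BigDecimal("2.24").setScale(1, RoundingMode.HALF_UP))  // 2.2
    println(BigDecimal("-2.25").setScale(1, RoundingMode.HALF_UP)) // -2.3
  }
}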