diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index b7573cb280444..1c6be6c2b1f0e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -22,6 +22,7 @@ import java.lang.{Long => JLong}
 import java.util.UUID
 
 import scala.jdk.CollectionConverters._
+import scala.math.BigDecimal.RoundingMode
 import scala.util.control.NonFatal
 
 import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
@@ -35,7 +36,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.annotation.Evolving
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.streaming.SafeJsonSerializer.{safeDoubleToJValue, safeMapToJValue}
+import org.apache.spark.sql.streaming.SafeJsonSerializer.{safeDecimalToJValue, safeMapToJValue}
 import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS
 
 /**
@@ -183,8 +184,8 @@ class StreamingQueryProgress private[spark] (
     ("batchId" -> JInt(batchId)) ~
     ("batchDuration" -> JInt(batchDuration)) ~
     ("numInputRows" -> JInt(numInputRows)) ~
-    ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
-    ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~
+    ("inputRowsPerSecond" -> safeDecimalToJValue(inputRowsPerSecond)) ~
+    ("processedRowsPerSecond" -> safeDecimalToJValue(processedRowsPerSecond)) ~
     ("durationMs" -> safeMapToJValue[JLong](durationMs, v => JInt(v.toLong))) ~
     ("eventTime" -> safeMapToJValue[String](eventTime, s => JString(s))) ~
     ("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~
@@ -255,8 +256,8 @@ class SourceProgress protected[spark] (
     ("endOffset" -> tryParse(endOffset)) ~
     ("latestOffset" -> tryParse(latestOffset)) ~
     ("numInputRows" -> JInt(numInputRows)) ~
-    ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
-    ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~
+    ("inputRowsPerSecond" -> safeDecimalToJValue(inputRowsPerSecond)) ~
+    ("processedRowsPerSecond" -> safeDecimalToJValue(processedRowsPerSecond)) ~
     ("metrics" -> safeMapToJValue[String](metrics, s => JString(s)))
   }
 
@@ -316,6 +317,8 @@ private[sql] object SinkProgress {
 }
 
 private object SafeJsonSerializer {
+
+  /** Convert Double to JValue while handling NaN or infinite values */
   def safeDoubleToJValue(value: Double): JValue = {
     if (value.isNaN || value.isInfinity) JNothing else JDouble(value)
   }
@@ -326,4 +329,10 @@ private object SafeJsonSerializer {
     val keys = map.asScala.keySet.toSeq.sorted
     keys.map { k => k -> valueToJValue(map.get(k)): JObject }.reduce(_ ~ _)
   }
+
+  /** Convert Double to JValue as JDecimal, avoiding scientific notation; handles NaN/infinity */
+  def safeDecimalToJValue(value: Double): JValue = {
+    if (value.isNaN || value.isInfinity) JNothing
+    else JDecimal(BigDecimal(value).setScale(1, RoundingMode.HALF_UP))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index e748ae8e7d7df..3a1dfd3088fa5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -23,10 +23,12 @@ import java.time.temporal.ChronoUnit
 import java.util.UUID
 
 import scala.jdk.CollectionConverters._
+import scala.math.BigDecimal.RoundingMode
 
 import org.json4s.jackson.JsonMethods._
 import org.scalatest.concurrent.Eventually
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.matchers.should.Matchers
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.Row
@@ -40,7 +42,7 @@ import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ArrayImplicits._
 
-class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
+class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually with Matchers {
   test("StreamingQueryProgress - prettyJson") {
     val json1 = testProgress1.prettyJson
     assertJson(
@@ -400,6 +402,40 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
     assert(data(0).getAs[Timestamp](0).equals(validValue))
   }
 
+  test("SPARK-53491: inputRowsPerSecond and processedRowsPerSecond " +
+    "should never be rendered in scientific notation") {
+    val progress = testProgress4.jsonValue
+
+    // Raw values fed into testProgress4
+    val inputRowsPerSecond: Double = 6.923076923076923E8
+    val processedRowsPerSecond: Double = 2.923076923076923E8
+
+    // Get the values from the progress JSON and cast them back to Double
+    // for numeric comparison
+    val inputRowsPerSecondJSON = (progress \ "inputRowsPerSecond").values.toString
+      .toDouble
+    val processedRowsPerSecondJSON = (progress \ "processedRowsPerSecond").values.toString
+      .toDouble
+
+    // Expected values after rounding to one decimal place
+    val inputRowsPerSecondExpected = BigDecimal(inputRowsPerSecond)
+      .setScale(1, RoundingMode.HALF_UP).toDouble
+    val processedRowsPerSecondExpected = BigDecimal(processedRowsPerSecond)
+      .setScale(1, RoundingMode.HALF_UP).toDouble
+
+    // This should fail if inputRowsPerSecond contains E notation
+    (progress \ "inputRowsPerSecond").values.toString should not include "E"
+
+    // This should fail if processedRowsPerSecond contains E notation
+    (progress \ "processedRowsPerSecond").values.toString should not include "E"
+
+    // Values in the progress JSON should equal the rounded originals.
+    // Using epsilon to compare floating-point values
+    val epsilon = 1e-6
+    inputRowsPerSecondJSON shouldBe inputRowsPerSecondExpected +- epsilon
+    processedRowsPerSecondJSON shouldBe processedRowsPerSecondExpected +- epsilon
+  }
+
   def waitUntilBatchProcessed: AssertOnQuery = Execute { q =>
     eventually(Timeout(streamingTimeout)) {
       if (q.exception.isEmpty) {
@@ -522,6 +558,44 @@ object StreamingQueryStatusAndProgressSuite {
     observedMetrics = null
   )
 
+  val testProgress4 = new StreamingQueryProgress(
+    id = UUID.randomUUID,
+    runId = UUID.randomUUID,
+    name = "myName",
+    timestamp = "2025-09-05T20:54:20.827Z",
+    batchId = 2L,
+    batchDuration = 0L,
+    durationMs = new java.util.HashMap(Map("total" -> 0L).transform((_, v) => long2Long(v)).asJava),
+    eventTime = new java.util.HashMap(Map(
+      "max" -> "2025-09-05T20:54:20.827Z",
+      "min" -> "2025-09-05T20:54:20.827Z",
+      "avg" -> "2025-09-05T20:54:20.827Z",
+      "watermark" -> "2025-09-05T20:54:20.827Z").asJava),
+    stateOperators = Array(new StateOperatorProgress(operatorName = "op1",
+      numRowsTotal = 0, numRowsUpdated = 1, allUpdatesTimeMs = 1, numRowsRemoved = 2,
+      allRemovalsTimeMs = 34, commitTimeMs = 23, memoryUsedBytes = 3, numRowsDroppedByWatermark = 0,
+      numShufflePartitions = 2, numStateStoreInstances = 2,
+      customMetrics = new java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L,
+        "loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L)
+        .transform((_, v) => long2Long(v)).asJava)
+    )),
+    sources = Array(
+      new SourceProgress(
+        description = "source",
+        startOffset = "123",
+        endOffset = "456",
+        latestOffset = "789",
+        numInputRows = 678,
+        inputRowsPerSecond = 6.923076923076923E8, // Large double rendered in scientific notation
+        processedRowsPerSecond = 2.923076923076923E8
+      )
+    ),
+    sink = SinkProgress("sink", None),
+    observedMetrics = new java.util.HashMap(Map(
+      "event1" -> row(schema1, 1L, 3.0d),
+      "event2" -> row(schema2, 1L, "hello", "world")).asJava)
+  )
+
   val testStatus = new StreamingQueryStatus("active", true, false)
 }
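
A note on why switching from JDouble to JDecimal removes the scientific notation (an illustrative sketch, not part of the patch): json4s serializes a JDouble through the underlying Double, and Double.toString switches to scientific notation for magnitudes of 1e7 and above, whereas a JDecimal is serialized through BigDecimal, which stays in plain notation once the value has been rounded to a fixed scale. The standalone object below (the name NotationDemo is made up for this sketch) mirrors what safeDecimalToJValue does:

import scala.math.BigDecimal.RoundingMode

import org.json4s.{JDecimal, JDouble}
import org.json4s.jackson.JsonMethods._

object NotationDemo {
  def main(args: Array[String]): Unit = {
    // Same value used for inputRowsPerSecond in testProgress4
    val rate = 6.923076923076923E8

    // JDouble goes out as a raw Double, so large magnitudes are
    // rendered in scientific notation: prints 6.923076923076923E8
    println(compact(render(JDouble(rate))))

    // JDecimal goes out via BigDecimal; after rounding to one decimal
    // place it is rendered in plain notation: prints 692307692.3
    println(compact(render(JDecimal(BigDecimal(rate).setScale(1, RoundingMode.HALF_UP)))))
  }
}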