
Commit ef0c8b3

A bit of progress cleaning up the stream-stream join examples. Open item: @holden needs to update STB (spark-testing-base) to handle strongly typed Datasets that are not just type aliases to DataFrame.
1 parent 42542ff commit ef0c8b3
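For context on the open STB item: DataFrame is only a type alias for Dataset[Row], so helpers written against DataFrame reject strongly typed Datasets unless you round-trip through .toDF(). A minimal sketch of the distinction, assuming the Ev case class added in the test suite below; the TypedVsUntypedSketch object and demo helper are illustrative only, not part of this commit:

import java.sql.Timestamp

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

final case class Ev(key: String, timestamp: Timestamp, v: Int)

object TypedVsUntypedSketch {
  def demo(spark: SparkSession): Unit = {
    import spark.implicits._
    // DataFrame is just an alias: type DataFrame = Dataset[Row].
    val df: DataFrame = Seq(Ev("A", new Timestamp(0L), 1)).toDF()
    // A strongly typed Dataset[Ev] is NOT a DataFrame, so APIs that expect
    // Dataset[Row] (e.g. STB's assertDataFrameEquals) won't accept it as-is.
    val typed: Dataset[Ev] = df.as[Ev]
    // Explicit round-trip back to the untyped representation:
    val untypedAgain: DataFrame = typed.toDF()
    untypedAgain.printSchema()
  }
}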

4 files changed (+147, -142 lines)


core/src/main/scala/com/high-performance-spark-examples/streaming/structuredstreaming/RateSourceStressExample.scala

Lines changed: 0 additions & 40 deletions
This file was deleted.

core/src/main/scala/com/high-performance-spark-examples/streaming/structuredstreaming/StreamStreamJoinBothSideWatermark.scala

Lines changed: 28 additions & 11 deletions
@@ -1,12 +1,11 @@
 package com.highperformancespark.examples.structuredstreaming
 
-// tag::stream_stream_join_basic_both_side_watermark[]
 // Stream-stream join with watermark on both sides
 // State can be cleaned up
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.sql.streaming._
 
 object StreamStreamJoinBothSideWatermark {
   def main(args: Array[String]): Unit = {
@@ -15,33 +14,51 @@ object StreamStreamJoinBothSideWatermark {
       .appName("StreamStreamJoinBothSideWatermark")
       .master("local[2]")
       .getOrCreate()
-    import spark.implicits._
+  }
 
+  def run(spark: SparkSession): Unit = {
     val left = spark.readStream
       .format("memory")
       .load()
-      .withWatermark("timestamp", "10 minutes")
+
     val right = spark.readStream
       .format("memory")
       .load()
-      .withWatermark("timestamp", "10 minutes")
+
+    val query = streamStreamJoin(spark, left, right)
+    query.awaitTermination()
+  }
+
+  def streamStreamJoinDF(spark: SparkSession, stream1: DataFrame, stream2: DataFrame): Dataset[Row] = {
+    // Note the watermarks don't need to be the same; by default Spark will pick the min.
+    // tag::stream_stream_join_basic_both_side_watermark[]
+    val left = stream1.withWatermark("timestamp", "10 minutes").alias("left")
+    val right = stream2.withWatermark("timestamp", "5 minutes").alias("right")
 
     val joined = left.join(
       right,
       expr(
-        "left.timestamp >= right.timestamp - interval 5 minutes AND left.timestamp <= right.timestamp + interval 5 minutes AND left.key = right.key"
+        "left.timestamp >= right.timestamp - interval 5 minutes " +
+          " AND left.timestamp <= right.timestamp + interval 5 minutes " +
+          " AND left.key = right.key"
       )
     )
+    // end::stream_stream_join_basic_both_side_watermark[]
+    joined
+  }
 
-    val query = joined.writeStream
+  def streamStreamJoin(spark: SparkSession, stream1: DataFrame, stream2: DataFrame): StreamingQuery = {
+    val joined = streamStreamJoinDF(spark, stream1, stream2)
+    // tag::ex_with_checkpoin_at_writet[]
+    val writer = joined.writeStream
       .outputMode("append")
       .format("console")
       .option(
         "checkpointLocation",
         "./tmp/checkpoints/stream_stream_join_both_side_watermark"
       )
-      .start()
-    query.awaitTermination()
+    // end::ex_with_checkpoin_at_writet[]
+    val query = writer.start()
+    query
   }
 }
-// end::stream_stream_join_basic_both_side_watermark[]
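On the new note that Spark picks the min of the two watermarks: the global watermark policy is configurable. A minimal sketch; the config key and its "min"/"max" values are real Spark settings (Spark 2.4+), while the WatermarkPolicySketch wrapper is illustrative only:

import org.apache.spark.sql.SparkSession

object WatermarkPolicySketch {
  def configure(spark: SparkSession): Unit = {
    // Default is "min": the global watermark is the minimum across all input
    // watermarks, so no row is dropped that either side might still need.
    // "max" cleans state sooner, at the cost of dropping more late rows
    // from the slower stream.
    spark.conf.set("spark.sql.streaming.multipleWatermarkPolicy", "max")
  }
}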

core/src/test/scala/com/high-performance-spark-examples/streaming/structuredstreaming/RateSourceStressExampleSuite.scala

Lines changed: 0 additions & 42 deletions
This file was deleted.

core/src/test/scala/com/high-performance-spark-examples/streaming/structuredstreaming/StreamStreamJoinBothSideWatermarkSuite.scala

Lines changed: 119 additions & 49 deletions
@@ -1,60 +1,130 @@
 package com.highperformancespark.examples.structuredstreaming
 
-// tag::stream_stream_join_basic_both_side_watermark_test[]
-// Test for stream-stream join with watermark on both sides
-// Verifies bounded state and correct join results
+import java.sql.Timestamp
+import java.nio.file.Files
 
-import org.scalatest.funsuite.AnyFunSuite
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import java.sql.Timestamp
+import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.scalatest.funsuite.AnyFunSuite
+
+// spark-testing-base
+import com.holdenkarau.spark.testing.DatasetSuiteBase
 
-class StreamStreamJoinBothSideWatermarkSuite extends AnyFunSuite {
-  test("join with both-side watermark yields bounded state and correct results") {
-    val spark = SparkSession.builder()
-      .master("local[2]")
-      .appName("StreamStreamJoinBothSideWatermarkSuite")
-      .getOrCreate()
-    import spark.implicits._
-
-    import org.apache.spark.sql.execution.streaming.MemoryStream
-    val now = System.currentTimeMillis()
-    val leftStream = MemoryStream[(Timestamp, String)](1, spark.sqlContext)
-    val rightStream = MemoryStream[(Timestamp, String)](2, spark.sqlContext)
-    val leftRows = Seq(
-      (new Timestamp(now - 1000 * 60 * 5), "k1"), // within window
-      (new Timestamp(now - 1000 * 60 * 20), "k2") // late, beyond watermark
-    )
-    val rightRows = Seq(
-      (new Timestamp(now - 1000 * 60 * 5), "k1"), // within window
-      (new Timestamp(now - 1000 * 60 * 20), "k2") // late, beyond watermark
-    )
-    leftStream.addData(leftRows: _*)
-    rightStream.addData(rightRows: _*)
-    val leftDF = leftStream.toDF().toDF("timestamp", "key").withWatermark("timestamp", "10 minutes")
-    val rightDF = rightStream.toDF().toDF("timestamp", "key").withWatermark("timestamp", "10 minutes")
-
-    val joined = leftDF.join(
-      rightDF,
-      leftDF("key") === rightDF("key") &&
-        leftDF("timestamp") >= rightDF("timestamp") - expr("interval 5 minutes") &&
-        leftDF("timestamp") <= rightDF("timestamp") + expr("interval 5 minutes")
-    )
-
-    val query = joined.writeStream
+final case class Ev(key: String, timestamp: Timestamp, v: Int)
+
+class StreamStreamJoinBothSideWatermarkSTBSpec
+    extends AnyFunSuite
+    with DatasetSuiteBase {
+
+  import spark.implicits._
+
+  private def ts(mins: Long): Timestamp =
+    new Timestamp(mins * 60L * 1000L) // epoch + minutes
+
+  private def joinedDF(leftIn: DataFrame, rightIn: DataFrame): DataFrame = {
+    StreamStreamJoinBothSideWatermark.streamStreamJoinDF(spark, leftIn, rightIn)
+  }
+
+  test("joins rows with same key within ±5 minutes") {
+    val leftMem = MemoryStream[Ev](1, spark.sqlContext)
+    val rightMem = MemoryStream[Ev](2, spark.sqlContext)
+
+    val outName = "stb_out_basic"
+    val q = joinedDF(leftMem.toDF(), rightMem.toDF())
+      .writeStream
+      .format("memory")
+      .queryName(outName)
       .outputMode("append")
+      .option("checkpointLocation", Files.createTempDirectory("chk-basic").toString)
+      .start()
+
+    // Left @ 10, Right @ 12 -> within window and same key
+    leftMem.addData(Ev("A", ts(10), 1))
+    rightMem.addData(Ev("A", ts(12), 2))
+    q.processAllAvailable()
+
+    // Select a stable set of columns to compare
+    val actual = spark.table(outName)
+      .selectExpr("left.key as key", "left.timestamp as lt", "right.timestamp as rt")
+      .as[(String, Timestamp, Timestamp)]
+
+    val expected = Seq(("A", ts(10), ts(12))).toDS()
+
+    assertDataFrameEquals(actual, expected)
+
+    q.stop()
+  }
+
+  test("does not join when outside tolerance or key mismatch") {
+    val leftMem = MemoryStream[Ev](3, spark.sqlContext)
+    val rightMem = MemoryStream[Ev](4, spark.sqlContext)
+
+    val outName = "stb_out_filtering"
+    val q = joinedDF(leftMem.toDF(), rightMem.toDF())
+      .writeStream
+      .format("memory")
+      .queryName(outName)
+      .outputMode("append")
+      .option("checkpointLocation", Files.createTempDirectory("chk-filter").toString)
+      .start()
+
+    // Outside ±5 minutes (0 vs 7 -> 7 minutes apart)
+    leftMem.addData(Ev("A", ts(0), 1))
+    rightMem.addData(Ev("A", ts(7), 2))
+    q.processAllAvailable()
+    assert(spark.table(outName).isEmpty)
+
+    // Within time but different keys
+    rightMem.addData(Ev("B", ts(2), 9))
+    q.processAllAvailable()
+    assert(spark.table(outName).isEmpty)
+
+    q.stop()
+  }
+
+  test("late data are dropped after both watermarks advance") {
+    val leftMem = MemoryStream[Ev](5, spark.sqlContext)
+    val rightMem = MemoryStream[Ev](6, spark.sqlContext)
+
+    val outName = "stb_out_late"
+    val q = joinedDF(leftMem.toDF(), rightMem.toDF())
+      .writeStream
       .format("memory")
-      .queryName("stream_stream_join_both_side_watermark")
-      .trigger(Trigger.Once())
-      .option("checkpointLocation", "./tmp/checkpoints/stream_stream_join_both_side_watermark_test")
+      .queryName(outName)
+      .outputMode("append")
+      .option("checkpointLocation", Files.createTempDirectory("chk-late").toString)
       .start()
-    query.processAllAvailable()
-    query.awaitTermination()
 
-    val result = spark.sql("select key from stream_stream_join_both_side_watermark").collect().map(_.getString(0)).toSet
-    assert(result == Set("k1"), "Only non-late key should join")
-    spark.stop()
+    // 1) Valid pair near t ~ 10..12
+    leftMem.addData(Ev("A", ts(10), 1))
+    rightMem.addData(Ev("A", ts(12), 2))
+    q.processAllAvailable()
+    assert(spark.table(outName).count() == 1)
+
+    // 2) Advance BOTH watermarks far ahead:
+    //    left WM delay 10m -> add t=100 -> WM ~ 90
+    //    right WM delay 5m -> add t=100 -> WM ~ 95
+    leftMem.addData(Ev("A", ts(100), 3))
+    rightMem.addData(Ev("A", ts(100), 4))
+    q.processAllAvailable()
+
+    // 3) Inject events that would have joined in the past (t=20..22)
+    //    but are now far older than both watermarks -> should be dropped.
+    leftMem.addData(Ev("A", ts(20), 5))
+    rightMem.addData(Ev("A", ts(22), 6))
+    q.processAllAvailable()
+
+    // Still only the first result
+    assert(spark.table(outName).count() == 1)
+
+    // Optional sanity: state metrics shouldn't balloon
+    Option(q.lastProgress).foreach { p =>
+      assert(p.stateOperators != null && p.stateOperators.nonEmpty)
+      assert(p.stateOperators.head.numRowsTotal >= 0)
+    }
+
+    q.stop()
   }
 }
-// end::stream_stream_join_basic_both_side_watermark_test[]
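The late-data test leans on the watermark arithmetic in its step 2/3 comments. The same back-of-the-envelope calculation spelled out, with values in minutes to match the ts() helper; the WatermarkArithmeticSketch object is illustrative only:

object WatermarkArithmeticSketch extends App {
  // watermark ~= max event time seen - declared delay (recomputed per trigger)
  val leftWm = 100 - 10                    // left delay 10m -> WM ~ 90
  val rightWm = 100 - 5                    // right delay 5m -> WM ~ 95
  val globalWm = math.min(leftWm, rightWm) // default "min" policy -> 90
  // The rows injected at t = 20..22 sit far below 90, so Spark drops them
  // instead of joining, leaving only the single earlier result.
  println(s"global watermark ~ $globalWm minutes")
}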
