diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 47db260e6b09a..b9583c4a11fa9 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -202,10 +202,14 @@ def onQueryTerminated(self, e): q.stop() # need to wait a while before QueryTerminatedEvent reaches client - time.sleep(5) - self.assertTrue(len(listener_good.start) > 0) - self.assertTrue(len(listener_good.progress) > 0) - self.assertTrue(len(listener_good.terminated) > 0) + + @eventually(timeout=5, catch_assertions=True) + def check_listner(): + self.assertTrue(len(listener_good.start) > 0) + self.assertTrue(len(listener_good.progress) > 0) + self.assertTrue(len(listener_good.terminated) > 0) + + check_listner() finally: for listener in self.spark.streams._sqlb._listener_bus: self.spark.streams.removeListener(listener) @@ -245,26 +249,27 @@ def test_listener_events_spark_command(self): q.stop() self.assertFalse(q.isActive) - time.sleep( - 60 - ) # Sleep to make sure listener_terminated_events is written successfully - - start_event = pyspark.cloudpickle.loads( - self.spark.read.table("listener_start_events").collect()[0][0] - ) - - progress_event = pyspark.cloudpickle.loads( - self.spark.read.table("listener_progress_events").collect()[0][0] - ) - - terminated_event = pyspark.cloudpickle.loads( - self.spark.read.table("listener_terminated_events").collect()[0][0] - ) - - self.check_start_event(start_event) - self.check_progress_event(progress_event, is_stateful=True) - self.check_terminated_event(terminated_event) - + events = { + "start_event": None, + "progress_event": None, + "terminated_event": None, + } + + @eventually(timeout=60, catch_assertions=True) + def load_event(event_name, table_name): + table = self.spark.read.table(table_name).collect() + if len(table) == 0: + return False + events[event_name] = pyspark.cloudpickle.loads(table[0][0]) + return True + + load_event("start_event", "listener_start_events") + load_event("progress_event", "listener_progress_events") + load_event("terminated_event", "listener_terminated_events") + + self.check_start_event(events["start_event"]) + self.check_progress_event(events["progress_event"], is_stateful=True) + self.check_terminated_event(events["terminated_event"]) finally: self.spark.streams.removeListener(test_listener) # Remove again to verify this won't throw any error