[SPARK-51981][SS] Add JobTags to queryStartedEvent #50780

Closed · wants to merge 4 commits
26 changes: 24 additions & 2 deletions python/pyspark/sql/streaming/listener.py
@@ -16,7 +16,7 @@
#
import uuid
import json
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
from abc import ABC, abstractmethod

from pyspark.sql import Row
@@ -178,29 +178,44 @@ class QueryStartedEvent:
"""

def __init__(
self, id: uuid.UUID, runId: uuid.UUID, name: Optional[str], timestamp: str
self,
id: uuid.UUID,
runId: uuid.UUID,
name: Optional[str],
timestamp: str,
jobTags: Set[str],
) -> None:
self._id: uuid.UUID = id
self._runId: uuid.UUID = runId
self._name: Optional[str] = name
self._timestamp: str = timestamp
self._jobTags: Set[str] = jobTags

@classmethod
def fromJObject(cls, jevent: "JavaObject") -> "QueryStartedEvent":
job_tags = set()
java_iterator = jevent.jobTags().iterator()
Member:
No biggie, but you can call set(jobTags().toList()), which will automatically give you a Python list. Each individual Py4J call is actually pretty expensive. But tags are supposed to be few, so I don't mind it. Leaving it to you, with my approval.

Author:
I tried, but it didn't work. See the comment I put above in #50780 (comment). I don't know why, though; this is the test result I got when running locally:

lingkai.kong@K9WHXLR93K spark % python/run-tests --testnames 'pyspark.sql.tests.streaming.test_streaming_listener StreamingListenerTests.test_listener_events'
Running PySpark tests. Output is in /Users/lingkai.kong/spark/python/unit-tests.log
Will test against the following Python executables: ['python3.9']
Will test the following Python tests: ['pyspark.sql.tests.streaming.test_streaming_listener StreamingListenerTests.test_listener_events']
python3.9 python_implementation is CPython
python3.9 version is: Python 3.9.6
Starting test(python3.9): pyspark.sql.tests.streaming.test_streaming_listener StreamingListenerTests.test_listener_events (temp output: /Users/lingkai.kong/spark/python/target/ed1f4b6e-6661-4815-84fd-00bf6cedd0ab/python3.9__pyspark.sql.tests.streaming.test_streaming_listener_StreamingListenerTests.test_listener_events__ck29eqqs.log)
WARNING: Using incubator modules: jdk.incubator.vector
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
test_listener_events (pyspark.sql.tests.streaming.test_streaming_listener.StreamingListenerTests) ... FAIL

======================================================================
FAIL: test_listener_events (pyspark.sql.tests.streaming.test_streaming_listener.StreamingListenerTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/lingkai.kong/spark/python/pyspark/sql/tests/streaming/test_streaming_listener.py", line 413, in test_listener_events
    verify(TestListenerV1())
  File "/Users/lingkai.kong/spark/python/pyspark/sql/tests/streaming/test_streaming_listener.py", line 396, in verify
    self.check_start_event(start_event)
  File "/Users/lingkai.kong/spark/python/pyspark/sql/tests/streaming/test_streaming_listener.py", line 40, in check_start_event
    self.assertTrue(isinstance(event, QueryStartedEvent))
AssertionError: False is not true

----------------------------------------------------------------------
Ran 1 test in 5.903s

FAILED (failures=1)

Had test failures in pyspark.sql.tests.streaming.test_streaming_listener StreamingListenerTests.test_listener_events with python3.9; see logs.

Could this be an issue related to package versions, etc.?

Member:
ah okie that's fine

Author:
Sounds good, thanks! Can you help me merge this PR once the tests pass?

while java_iterator.hasNext():
job_tags.add(java_iterator.next().toString())

return cls(
id=uuid.UUID(jevent.id().toString()),
runId=uuid.UUID(jevent.runId().toString()),
name=jevent.name(),
timestamp=jevent.timestamp(),
jobTags=job_tags,
)

@classmethod
def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent":
# Json4s will convert jobTags to a list, so we need to convert it back to a set.
job_tags = j["jobTags"] if "jobTags" in j else []
return cls(
id=uuid.UUID(j["id"]),
runId=uuid.UUID(j["runId"]),
name=j["name"],
timestamp=j["timestamp"],
jobTags=set(job_tags),
)

@property
@@ -233,6 +248,13 @@ def timestamp(self) -> str:
"""
return self._timestamp

@property
def jobTags(self) -> Set[str]:
"""
The job tags of the query.
"""
return self._jobTags


class QueryProgressEvent:
"""
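For illustration, here is a minimal sketch (not part of this PR) of how a user-defined listener could consume the new field. The StreamingQueryListener API and addListener call are existing PySpark API; the class name and the print statement are made up for this example:

from pyspark.sql.streaming.listener import StreamingQueryListener

class JobTagPrinter(StreamingQueryListener):
    """Illustrative listener that inspects the new jobTags field."""

    def onQueryStarted(self, event):
        # event.jobTags is the Set[str] added by this PR; it reflects the job
        # tags set on the SparkContext of the thread that started the query.
        print(f"query {event.id} started with job tags: {sorted(event.jobTags)}")

    def onQueryProgress(self, event):
        pass

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        pass

# spark.streams.addListener(JobTagPrinter())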
35 changes: 27 additions & 8 deletions python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -45,6 +45,7 @@ def check_start_event(self, event):
datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
except ValueError:
self.fail("'%s' is not in ISO 8601 format.")
self.assertTrue(isinstance(event.jobTags, set))

def check_progress_event(self, event, is_stateful):
"""Check QueryProgressEvent"""
@@ -287,7 +288,7 @@ def get_number_of_public_methods(clz):
get_number_of_public_methods(
"org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"
),
15,
16,
msg,
)
self.assertEqual(
@@ -451,20 +452,38 @@ def verify(test_listener):
verify(TestListenerV2())

def test_query_started_event_fromJson(self):
start_event = """
start_event_old = """
{
"id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
"runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
"name" : null,
"timestamp" : "2023-06-09T18:13:29.741Z"
}
"""
start_event = QueryStartedEvent.fromJson(json.loads(start_event))
self.check_start_event(start_event)
self.assertEqual(start_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
self.assertEqual(start_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
self.assertIsNone(start_event.name)
self.assertEqual(start_event.timestamp, "2023-06-09T18:13:29.741Z")
start_event_old = QueryStartedEvent.fromJson(json.loads(start_event_old))
self.check_start_event(start_event_old)
self.assertEqual(start_event_old.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
self.assertEqual(start_event_old.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
self.assertIsNone(start_event_old.name)
self.assertEqual(start_event_old.timestamp, "2023-06-09T18:13:29.741Z")
self.assertEqual(start_event_old.jobTags, set())

start_event_new = """
{
"id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
"runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
"name" : null,
"timestamp" : "2023-06-09T18:13:29.741Z",
"jobTags": ["jobTag1", "jobTag2"]
}
"""
start_event_new = QueryStartedEvent.fromJson(json.loads(start_event_new))
self.check_start_event(start_event_new)
self.assertEqual(start_event_new.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
self.assertEqual(start_event_new.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
self.assertIsNone(start_event_new.name)
self.assertEqual(start_event_new.timestamp, "2023-06-09T18:13:29.741Z")
self.assertEqual(start_event_new.jobTags, set(["jobTag1", "jobTag2"]))

def test_query_terminated_event_fromJson(self):
terminated_json = """
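As a quick standalone sanity check of the backward-compatibility path exercised above (a hedged sketch; the UUIDs are reused from the test fixture), old-format JSON without a "jobTags" key should deserialize to an empty set:

import json
import uuid

from pyspark.sql.streaming.listener import QueryStartedEvent

# Old-format JSON emitted before this change: no "jobTags" key at all.
old_json = """
{
  "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
  "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
  "name" : null,
  "timestamp" : "2023-06-09T18:13:29.741Z"
}
"""

event = QueryStartedEvent.fromJson(json.loads(old_json))
assert event.jobTags == set()  # missing key falls back to an empty set
assert event.id == uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")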
@@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming

import java.util.UUID

import scala.jdk.CollectionConverters._

import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.databind.node.TreeTraversingParser
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
import org.json4s.{JObject, JString}
import org.json4s.{JArray, JObject, JString}
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
import org.json4s.jackson.JsonMethods.{compact, render}
@@ -123,6 +125,14 @@ object StreamingQueryListener extends Serializable {
private val tree = mapper.readTree(json)
def getString(name: String): String = tree.get(name).asText()
def getUUID(name: String): UUID = UUID.fromString(getString(name))
def getStringArray(name: String): List[String] = {
val node = tree.get(name)
if (node.isArray()) {
node.elements().asScala.map(_.asText()).toList
} else {
List()
}
}
def getProgress(name: String): StreamingQueryProgress = {
val parser = new TreeTraversingParser(tree.get(name), mapper)
parser.readValueAs(classOf[StreamingQueryProgress])
@@ -146,24 +156,32 @@
* User-specified name of the query, null if not specified.
* @param timestamp
* The timestamp to start a query.
* @param jobTags
The job tags that have been assigned to all the jobs started by this thread.
* @since 2.1.0
*/
@Evolving
class QueryStartedEvent private[sql] (
val id: UUID,
val runId: UUID,
val name: String,
val timestamp: String)
val timestamp: String,
val jobTags: Set[String])
extends Event
with Serializable {

def this(id: UUID, runId: UUID, name: String, timestamp: String) = {
this(id, runId, name, timestamp, Set())
}

def json: String = compact(render(jsonValue))

private def jsonValue: JValue = {
("id" -> JString(id.toString)) ~
("runId" -> JString(runId.toString)) ~
("name" -> JString(name)) ~
("timestamp" -> JString(timestamp))
("timestamp" -> JString(timestamp)) ~
("jobTags" -> JArray(jobTags.toList.map(JString)))
}
}

@@ -175,7 +193,8 @@ object StreamingQueryListener extends Serializable {
parser.getUUID("id"),
parser.getUUID("runId"),
parser.getString("name"),
parser.getString("name"))
parser.getString("timestamp"),
parser.getStringArray("jobTags").toSet)
}
}

@@ -288,7 +288,14 @@ abstract class StreamExecution(
// `postEvent` does not throw non fatal exception.
val startTimestamp = triggerClock.getTimeMillis()
postEvent(
new QueryStartedEvent(id, runId, name, progressReporter.formatTimestamp(startTimestamp)))
new QueryStartedEvent(
id,
runId,
name,
progressReporter.formatTimestamp(startTimestamp),
sparkSession.sparkContext.getJobTags()
Contributor:
How many elements are we anticipating here? The size of the event should stay reasonably small.

Author:
Yeah, it should be quite small; we already include the job tags in other listener events, for example here.

Contributor:
Thanks for the explanation.

)
)

// Unblock starting thread
startLatch.countDown()
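Since StreamExecution now snapshots sparkSession.sparkContext.getJobTags() when the query starts, tags added before start() surface on the event. A rough PySpark sketch of that flow, assuming an active `spark` session and the listener sketched earlier; the tag name and the rate/noop formats are only for illustration:

# Tags added to the SparkContext before the query starts are captured by
# StreamExecution and carried on the QueryStartedEvent.
spark.sparkContext.addJobTag("nightly-backfill")

query = (
    spark.readStream.format("rate").load()
    .writeStream.format("noop").start()
)
# A registered listener (e.g. JobTagPrinter above) would now observe
# "nightly-backfill" in event.jobTags for this query's start event.

query.stop()
spark.sparkContext.clearJobTags()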
@@ -262,12 +262,28 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(newEvent.id === event.id)
assert(newEvent.runId === event.runId)
assert(newEvent.name === event.name)
assert(newEvent.timestamp === event.timestamp)
assert(newEvent.jobTags === event.jobTags)
}

testSerialization(
new QueryStartedEvent(UUID.randomUUID, UUID.randomUUID, "name", "2016-12-05T20:54:20.827Z"))
new QueryStartedEvent(
UUID.randomUUID,
UUID.randomUUID,
"name",
"2016-12-05T20:54:20.827Z",
Set()
)
)
testSerialization(
new QueryStartedEvent(UUID.randomUUID, UUID.randomUUID, null, "2016-12-05T20:54:20.827Z"))
new QueryStartedEvent(
UUID.randomUUID,
UUID.randomUUID,
null,
"2016-12-05T20:54:20.827Z",
Set("tag1", "tag2")
)
)
}

test("QueryProgressEvent serialization") {
@@ -349,6 +365,32 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}

test("QueryStartedEvent has the right jobTags set") {
val session = spark.newSession()
val collector = new EventCollectorV2
val jobTag1 = "test-jobTag1"
val jobTag2 = "test-jobTag2"

def runQuery(session: SparkSession): Unit = {
collector.reset()
session.sparkContext.addJobTag(jobTag1)
session.sparkContext.addJobTag(jobTag2)
val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext)
testStream(mem.toDS())(
AddData(mem, 1, 2, 3),
CheckAnswer(1, 2, 3)
)
session.sparkContext.listenerBus.waitUntilEmpty()
session.sparkContext.clearJobTags()
}

withListenerAdded(collector, session) {
runQuery(session)
assert(collector.startEvent !== null)
assert(collector.startEvent.jobTags === Set(jobTag1, jobTag2))
}
}

test("listener only posts events from queries started in the related sessions") {
val session1 = spark.newSession()
val session2 = spark.newSession()