[SPARK-51429][Connect] Add "Acknowledgement" message to ExecutePlanResponse #50193

Draft · wants to merge 5 commits into base: master
218 changes: 110 additions & 108 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1567,6 +1567,21 @@ class ExecutePlanResponse(google.protobuf.message.Message):
],
) -> None: ...

class Acknowledgement(google.protobuf.message.Message):
"""Server acknowledgement sent immediately upon registering an ExecutePlan or ReattachExecute
request.
This acknowledgement allows a client to disconnect right after registration, without waiting
for the full processing of the request.
It is especially useful when the server supports reattachment or other forms of early
termination of the request.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

def __init__(
self,
) -> None: ...

SESSION_ID_FIELD_NUMBER: builtins.int
SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
OPERATION_ID_FIELD_NUMBER: builtins.int
@@ -1583,6 +1598,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
EXECUTION_PROGRESS_FIELD_NUMBER: builtins.int
CHECKPOINT_COMMAND_RESULT_FIELD_NUMBER: builtins.int
ML_COMMAND_RESULT_FIELD_NUMBER: builtins.int
ACKNOWLEDGEMENT_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
METRICS_FIELD_NUMBER: builtins.int
OBSERVED_METRICS_FIELD_NUMBER: builtins.int
@@ -1650,6 +1666,11 @@ class ExecutePlanResponse(google.protobuf.message.Message):
def ml_command_result(self) -> pyspark.sql.connect.proto.ml_pb2.MlCommandResult:
"""ML command response"""
@property
def acknowledgement(self) -> global___ExecutePlanResponse.Acknowledgement:
"""Acknowledgement sent by the server immediately upon registration of an ExecutePlan or
ReattachExecute request.
"""
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""Support arbitrary result objects."""
@property
@@ -1692,6 +1713,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
execution_progress: global___ExecutePlanResponse.ExecutionProgress | None = ...,
checkpoint_command_result: global___CheckpointCommandResult | None = ...,
ml_command_result: pyspark.sql.connect.proto.ml_pb2.MlCommandResult | None = ...,
acknowledgement: global___ExecutePlanResponse.Acknowledgement | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
metrics: global___ExecutePlanResponse.Metrics | None = ...,
observed_metrics: collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics]
@@ -1701,6 +1723,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
def HasField(
self,
field_name: typing_extensions.Literal[
"acknowledgement",
b"acknowledgement",
"arrow_batch",
b"arrow_batch",
"checkpoint_command_result",
@@ -1738,6 +1762,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
"acknowledgement",
b"acknowledgement",
"arrow_batch",
b"arrow_batch",
"checkpoint_command_result",
@@ -1798,6 +1824,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
"execution_progress",
"checkpoint_command_result",
"ml_command_result",
"acknowledgement",
"extension",
]
| None
13 changes: 13 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -399,6 +399,10 @@ message ExecutePlanResponse {
// ML command response
MlCommandResult ml_command_result = 20;

// Acknowledgement sent by the server immediately upon registration of an ExecutePlan or
// ReattachExecute request.
Acknowledgement acknowledgement = 21;

// Support arbitrary result objects.
google.protobuf.Any extension = 999;
}
@@ -477,6 +481,15 @@ message ExecutePlanResponse {
bool done = 5;
}
}

message Acknowledgement {
// Server acknowledgement sent immediately upon registering an ExecutePlan or ReattachExecute
// request.
// This acknowledgement allows a client to disconnect right after registration, without waiting
// for the full processing of the request.
// It is especially useful when the server supports reattachment or other forms of early
// termination of the request.
}
}

// The key-value pair for the config request and response.
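The comment above defines the contract: the acknowledgement is the first message on the stream and carries no payload. A minimal client-side sketch of how this enables early disconnection, not part of this PR; it assumes the standard grpc-java blocking stub generated for SparkConnectService, and `registerAndDetach` is an illustrative name:

```scala
import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub

// Hedged sketch: start an ExecutePlan, wait only for the registration
// acknowledgement, then detach. Server-streaming calls on a grpc-java
// blocking stub return a java.util.Iterator over the responses.
def registerAndDetach(
    stub: SparkConnectServiceBlockingStub,
    request: ExecutePlanRequest): Unit = {
  val responses = stub.executePlan(request)
  val first: ExecutePlanResponse = responses.next()
  // The acknowledgement carries no fields; it only signals that the
  // operation is registered server-side.
  assert(first.hasAcknowledgement)
  // The client may now disconnect and later resume consuming results via
  // ReattachExecute with the same operationId.
}
```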
@@ -158,7 +158,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T])
.setWriteStreamOperationStart(sinkBuilder.build())
.build()

val resp = ds.sparkSession.execute(startCmd).head
val resp = ds.sparkSession.execute(startCmd).find(_.hasWriteStreamOperationStartResult).get
if (resp.getWriteStreamOperationStartResult.hasQueryStartedEventJson) {
val event = QueryStartedEvent.fromJson(
resp.getWriteStreamOperationStartResult.getQueryStartedEventJson)
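The change above guards against the new first message: with an acknowledgement now leading every stream, `.head` would return the acknowledgement rather than the write-start result, so the client instead picks the first response carrying the expected field. The two streaming-client files that follow apply the same pattern. A generic sketch of the idea (hypothetical helper, not in the PR):

```scala
import org.apache.spark.connect.proto.ExecutePlanResponse

// Hypothetical helper: skip the leading acknowledgement (and any other
// unrelated messages) and return the first response matching a predicate.
def firstResult(responses: Seq[ExecutePlanResponse])(
    hasResult: ExecutePlanResponse => Boolean): ExecutePlanResponse =
  responses
    .find(hasResult)
    .getOrElse(throw new NoSuchElementException("no matching response in stream"))
```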
@@ -126,7 +126,8 @@ class RemoteStreamingQuery(
// Set command.
setCmdFn(queryCmdBuilder)

val resp = sparkSession.execute(cmdBuilder.build()).head
val resp =
sparkSession.execute(cmdBuilder.build()).find(_.hasStreamingQueryCommandResult).head

if (!resp.hasStreamingQueryCommandResult) {
throw new RuntimeException("Unexpected missing response for streaming query command")
@@ -119,7 +119,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession)
// Set command.
setCmdFn(managerCmdBuilder)

val resp = sparkSession.execute(cmdBuilder.build()).head
val resp =
sparkSession.execute(cmdBuilder.build()).find(_.hasStreamingQueryManagerCommandResult).head

if (!resp.hasStreamingQueryManagerCommandResult) {
throw new RuntimeException(
@@ -190,6 +190,18 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
deadlineTimeNs = startTime + (1000L * 60L * 60L * 24L * 180L * NANOS_PER_MILLIS)
}

/**
* Enqueue an acknowledgement message to the response observer.
*/
private def enqueueAckResponse(): Unit = {
logDebug(s"Enqueue acknowledgement for opId=${executeHolder.operationId}")
val ackResponse = ExecutePlanResponse
.newBuilder()
.setAcknowledgement(ExecutePlanResponse.Acknowledgement.newBuilder().build())
.build()
executeHolder.responseObserver.tryOnNext(ackResponse)
}

/**
* Attach to the executionObserver, consume responses from it, and send them to grpcObserver.
*
@@ -203,6 +215,7 @@
* that. 0 means start from beginning (since first response has index 1)
*/
def execute(lastConsumedStreamIndex: Long): Unit = {
enqueueAckResponse()
logInfo(
log"Starting for opId=${MDC(OP_ID, executeHolder.operationId)}, " +
log"reattachable=${MDC(REATTACHABLE, executeHolder.reattachable)}, " +
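Since `execute` runs once per attachment, the acknowledgement is enqueued again for every ReattachExecute, so each (re)attached stream begins with one; the suites below check exactly that. Consumers that only care about results can drop acknowledgements up front, mirroring the filter added to SparkConnectServiceSuite later in this diff. A standalone sketch, assuming the responses have already been collected:

```scala
import org.apache.spark.connect.proto.ExecutePlanResponse

// Sketch: strip acknowledgement messages before counting or inspecting
// result responses, as the updated service-suite assertions do.
def withoutAcks(responses: Seq[ExecutePlanResponse]): Seq[ExecutePlanResponse] =
  responses.filterNot(_.hasAcknowledgement)
```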
@@ -452,4 +452,21 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
assert(re.getMessage.contains("INVALID_HANDLE.OPERATION_NOT_FOUND"))
}
}

test("Acknowledgement message is received") {
withRawBlockingStub { stub =>
val operationId = UUID.randomUUID().toString
val iter = stub.executePlan(
buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId))
val response = iter.next()
assert(response.hasAcknowledgement)
assert(!iter.next().hasAcknowledgement)

// send reattach
val iter2 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
val reattachResponse = iter2.next()
assert(reattachResponse.hasAcknowledgement)
assert(!iter2.next().hasAcknowledgement)
}
}
}
@@ -201,7 +201,8 @@ class SparkConnectServiceSuite
assert(done)

// 4 Partitions + Metrics + optional progress messages
val filteredResponses = responses.filter(!_.hasExecutionProgress)
val filteredResponses =
responses.filter(x => !(x.hasExecutionProgress || x.hasAcknowledgement))
assert(filteredResponses.size == 6)

// Make sure the first response is schema only
@@ -302,7 +303,8 @@ class SparkConnectServiceSuite
assert(done)

// 1 Partitions + Metrics
val filteredResponses = responses.filter(!_.hasExecutionProgress)
val filteredResponses =
responses.filter(x => !(x.hasExecutionProgress || x.hasAcknowledgement))
assert(filteredResponses.size == 3)

// Make sure the first response is schema only
@@ -358,7 +360,8 @@ class SparkConnectServiceSuite
assert(done)

// 1 schema + 1 metric + at least 2 data batches
val filteredResponses = responses.filter(!_.hasExecutionProgress)
val filteredResponses =
responses.filter(x => !(x.hasExecutionProgress || x.hasAcknowledgement))
assert(filteredResponses.size > 3)

val allocator = new RootAllocator()
@@ -539,7 +542,8 @@ class SparkConnectServiceSuite
assert(done)

// Result + Metrics
val filteredResponses = responses.filter(!_.hasExecutionProgress)
val filteredResponses =
responses.filter(x => !(x.hasExecutionProgress || x.hasAcknowledgement))
if (filteredResponses.size > 1) {
assert(filteredResponses.size == 2)

@@ -793,7 +797,8 @@ class SparkConnectServiceSuite
// The current implementation is expected to be blocking. This is here to make sure it is.
assert(done)

val filteredResponses = responses.filter(!_.hasExecutionProgress)
val filteredResponses =
responses.filter(x => !(x.hasExecutionProgress || x.hasAcknowledgement))
assert(filteredResponses.size == 7)

// Make sure the first response is schema only