Commit 51e8634

[SPARK-47380][CONNECT] Ensure on the server side that the SparkSession is the same

### What changes were proposed in this pull request?

In this PR we change the client behaviour to send the previously observed server-side session id, so that the server can validate that the client has been talking to this specific session. Previously this was validated only on the client side, which meant the server would actually execute the request against the wrong session before the client threw an error (once the response from the server was received).
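As a rough sketch (illustrative names only, not the actual Spark Connect classes), the new handshake works like this: the client remembers the server-side session id returned with the first response and echoes it on every subsequent request, and the server now rejects a mismatching request before executing anything.

```scala
// Illustrative sketch of the handshake; all names here are hypothetical.
final class Server(val serverSideSessionId: String) {
  def handle(
      sessionId: String,
      clientObservedServerSideSessionId: Option[String]): String = {
    // New behaviour: validate up front, before executing anything.
    clientObservedServerSideSessionId.foreach { observed =>
      require(
        observed == serverSideSessionId,
        "SESSION_CHANGED: The existing Spark server driver instance has restarted.")
    }
    serverSideSessionId // every response carries the server-side session id
  }
}

final class Client(server: Server, sessionId: String) {
  private var observed: Option[String] = None

  def execute(): Unit = {
    // The first request carries no observed id; later requests echo it back.
    observed = Some(server.handle(sessionId, observed))
  }
}
```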

### Why are the changes needed?
Without this change, the server can execute a client command against the wrong Spark session before the client figures out that it is talking to a different session.

### Does this PR introduce _any_ user-facing change?
The error message now surfaces differently (it used to be a slightly different message when the validation happened on the client).

### How was this patch tested?
Existing unit tests, a new unit test, a new end-to-end test, and manual testing.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45499 from nemanja-boric-databricks/workspace.

Authored-by: Nemanja Boric <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
nemanja-boric-databricks authored and hvanhovell committed Mar 18, 2024
1 parent a40940a commit 51e8634
Showing 26 changed files with 655 additions and 220 deletions.
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -2083,6 +2083,11 @@
"Operation not found."
]
},
"SESSION_CHANGED" : {
"message" : [
"The existing Spark server driver instance has restarted. Please reconnect."
]
},
"SESSION_CLOSED" : {
"message" : [
"Session was closed."
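For context, entries in error-classes.json are normally surfaced through Spark's error framework. A hedged sketch follows; the parent error class (plausibly INVALID_HANDLE, given the neighbouring SESSION_CLOSED sub-class) and its parameters are assumptions, not shown in this hunk.

```scala
import org.apache.spark.SparkException

// Assumption: SESSION_CHANGED is a sub-class of INVALID_HANDLE, so the full
// error-class name and the "handle" parameter here are illustrative only.
def sessionChangedError(sessionId: String): SparkException = {
  new SparkException(
    errorClass = "INVALID_HANDLE.SESSION_CHANGED",
    messageParameters = Map("handle" -> sessionId),
    cause = null)
}
```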
@@ -66,6 +66,12 @@ message AnalyzePlanRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 17;

// (Required) User context
UserContext user_context = 2;

@@ -281,6 +287,12 @@ message ExecutePlanRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 8;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote spark session on the
@@ -443,6 +455,12 @@ message ConfigRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 8;

// (Required) User context
UserContext user_context = 2;

@@ -536,6 +554,12 @@ message AddArtifactsRequest {
// User context
UserContext user_context = 2;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
@@ -630,6 +654,12 @@ message ArtifactStatusesRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 5;

// User context
UserContext user_context = 2;

@@ -673,6 +703,12 @@ message InterruptRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// (Required) User context
UserContext user_context = 2;

@@ -738,6 +774,12 @@ message ReattachExecuteRequest {
// This must be an id of existing session.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 6;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote spark session on the
@@ -772,6 +814,12 @@ message ReleaseExecuteRequest {
// This must be an id of existing session.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote spark session on the
@@ -856,6 +904,12 @@ message FetchErrorDetailsRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 5;

// User context
UserContext user_context = 2;

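On the wire, all nine request messages shown above gain the same proto3 `optional` field. A client-side usage sketch, grounded in the builder calls visible further down in this diff (`sessionId`, `userAgent`, `userContext`, `serverSideSessionId` and `bstub` are assumed to be in scope, as in SparkConnectClient):

```scala
// Sketch: mirrors the pattern used in SparkConnectClient below. A proto3
// `optional` field generates has/set/get accessors, so the field is set only
// once a server-side session id has actually been observed.
val request = proto.ConfigRequest
  .newBuilder()
  .setSessionId(sessionId)
  .setClientType(userAgent)
  .setUserContext(userContext)
serverSideSessionId.foreach(request.setClientObservedServerSideSessionId)
bstub.config(request.build())
```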
@@ -21,15 +21,27 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.internal.Logging

// This is common logic to be shared between different stub instances to validate responses as
// seen by the client.
// This is common logic to be shared between different stub instances to keep the server-side
// session id and to validate responses as seen by the client.
class ResponseValidator extends Logging {

// Server side session ID, used to detect if the server side session changed. This is set upon
// receiving the first response from the server. This value is used only for executions that
// do not use server-side streaming.
private var serverSideSessionId: Option[String] = None

// Returns the server side session ID, used to send it back to the server in the follow-up
// requests so the server can validate its session id against the previous requests.
def getServerSideSessionId: Option[String] = serverSideSessionId

/**
* Hijacks the stored server side session ID by appending the given suffix. Used for testing
* to make sure that the server is validating the session ID.
*/
private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = {
serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
}

def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
val response = fn
val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
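The hunk above is truncated right after the field lookup, so the continuation below is a hedged sketch of how verifyResponse presumably proceeds, using the serverSideSessionId field defined earlier in the class (an assumption about, not a copy of, the real method body):

```scala
// Hedged sketch of the continuation of verifyResponse (assumed, not shown).
def verifyResponseSketch[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
  val response = fn
  val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
  if (field != null) {
    val observed = response.getField(field).toString
    serverSideSessionId match {
      case Some(previous) if previous != observed =>
        // Client-side detection stays as a fallback even though the server
        // now validates the observed id as well.
        throw new IllegalStateException(
          s"Server side session ID changed from $previous to $observed")
      case None if observed.nonEmpty =>
        serverSideSessionId = Some(observed) // remember the first observed id
      case _ => // unchanged: nothing to do
    }
  }
  response
}
```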
@@ -63,6 +63,14 @@ private[sql] class SparkConnectClient(
// a new client will create a new session ID.
private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)

/**
* Hijacks the stored server side session ID by appending the given suffix. Used for testing
* to make sure that the server is validating the session ID.
*/
private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = {
stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
}

private[sql] val artifactManager: ArtifactManager = {
new ArtifactManager(configuration, sessionId, bstub, stub)
}
@@ -73,6 +81,14 @@ private[sql] class SparkConnectClient(
private[sql] def uploadAllClassFileArtifacts(): Unit =
artifactManager.uploadAllClassFileArtifacts()

/**
* Returns the server-side session id obtained from the first request, if there was a request
* already.
*/
private def serverSideSessionId: Option[String] = {
stubState.responseValidator.getServerSideSessionId
}

/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
@@ -99,11 +115,11 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
.build()
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
if (configuration.useReattachableExecute) {
bstub.executePlanReattachable(request)
bstub.executePlanReattachable(request.build())
} else {
bstub.executePlan(request)
bstub.executePlan(request.build())
}
}

@@ -119,8 +135,8 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.setUserContext(userContext)
.build()
bstub.config(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.config(request.build())
}

/**
@@ -207,8 +223,8 @@ private[sql] class SparkConnectClient(
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
analyze(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
analyze(request.build())
}

private[sql] def interruptAll(): proto.InterruptResponse = {
@@ -218,8 +234,8 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def interruptTag(tag: String): proto.InterruptResponse = {
@@ -230,8 +246,8 @@ private[sql] class SparkConnectClient(
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
.setOperationTag(tag)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def interruptOperation(id: String): proto.InterruptResponse = {
@@ -242,8 +258,8 @@ private[sql] class SparkConnectClient(
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
.setOperationId(id)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def releaseSession(): proto.ReleaseSessionResponse = {
@@ -252,8 +268,7 @@ private[sql] class SparkConnectClient(
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
bstub.releaseSession(request)
bstub.releaseSession(request.build())
}

private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
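A sketch of how the new testing hook might be exercised end to end (test scaffolding is assumed: a ScalaTest suite with a connected `spark` session and access to its underlying `client`):

```scala
// Corrupt the cached server-side session id, then expect the *server* to
// reject the next RPC instead of the client. Scaffolding is assumed.
client.hijackServerSideSessionIdForTesting("-corrupted")
val e = intercept[Exception] {
  spark.range(10).collect() // any RPC will do; it should now fail server-side
}
assert(e.getMessage.contains("SESSION_CHANGED"))
```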
@@ -53,9 +53,14 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr

override def onNext(req: AddArtifactsRequest): Unit = try {
if (this.holder == null) {
val previousSessionId = req.hasClientObservedServerSideSessionId match {
case true => Some(req.getClientObservedServerSideSessionId)
case false => None
}
this.holder = SparkConnectService.getOrCreateIsolatedSession(
req.getUserContext.getUserId,
req.getSessionId)
req.getSessionId,
previousSessionId)
}

if (req.hasBeginChunk) {
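This match on a Boolean is repeated nearly verbatim in every handler below; an equivalent, more compact form (a style observation only, not part of this change) would be:

```scala
// Equivalent to the match-on-Boolean used throughout these handlers
// (Option.when is available on Scala 2.13+).
val previousSessionId: Option[String] =
  Option.when(req.hasClientObservedServerSideSessionId)(
    req.getClientObservedServerSideSessionId)
```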
@@ -34,9 +34,14 @@ private[connect] class SparkConnectAnalyzeHandler(
extends Logging {

def handle(request: proto.AnalyzePlanRequest): Unit = {
val previousSessionId = request.hasClientObservedServerSideSessionId match {
case true => Some(request.getClientObservedServerSideSessionId)
case false => None
}
val sessionHolder = SparkConnectService.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId)
request.getSessionId,
previousSessionId)
// `withSession` ensures that session-specific artifacts (such as JARs and class files) are
// available during processing (such as deserialization).
sessionHolder.withSession { _ =>
@@ -28,17 +28,28 @@ class SparkConnectArtifactStatusesHandler(
val responseObserver: StreamObserver[proto.ArtifactStatusesResponse])
extends Logging {

protected def cacheExists(userId: String, sessionId: String, hash: String): Boolean = {
protected def cacheExists(
userId: String,
sessionId: String,
previouslySeenSessionId: Option[String],
hash: String): Boolean = {
val session = SparkConnectService
.getOrCreateIsolatedSession(userId, sessionId)
.getOrCreateIsolatedSession(userId, sessionId, previouslySeenSessionId)
.session
val blockManager = session.sparkContext.env.blockManager
blockManager.getStatus(CacheId(session.sessionUUID, hash)).isDefined
}

def handle(request: proto.ArtifactStatusesRequest): Unit = {
val previousSessionId = request.hasClientObservedServerSideSessionId match {
case true => Some(request.getClientObservedServerSideSessionId)
case false => None
}
val holder = SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)

val builder = proto.ArtifactStatusesResponse.newBuilder()
builder.setSessionId(holder.sessionId)
@@ -49,6 +60,7 @@
cacheExists(
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId,
previouslySeenSessionId = previousSessionId,
hash = name.stripPrefix("cache/"))
} else false
builder.putStatuses(name, status.setExists(exists).build())
@@ -30,9 +30,16 @@ class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigRes
extends Logging {

def handle(request: proto.ConfigRequest): Unit = {
val previousSessionId = request.hasClientObservedServerSideSessionId match {
case true => Some(request.getClientObservedServerSideSessionId)
case false => None
}
val holder =
SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)
val session = holder.session

val builder = request.getOperation.getOpTypeCase match {
@@ -67,8 +67,15 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* Create a new ExecuteHolder and register it with this global manager and with its session.
*/
private[connect] def createExecuteHolder(request: proto.ExecutePlanRequest): ExecuteHolder = {
val previousSessionId = request.hasClientObservedServerSideSessionId match {
case true => Some(request.getClientObservedServerSideSessionId)
case false => None
}
val sessionHolder = SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)
val executeHolder = new ExecuteHolder(request, sessionHolder)
executionsLock.synchronized {
// Check if the operation already exists, both in active executions, and in the graveyard
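The body of getOrCreateIsolatedSession itself is not part of the shown hunks; below is a hedged sketch of the validation it presumably performs with the new third parameter (`sessionStore` and the holder's `serverSessionId` field are hypothetical stand-ins for the real session registry):

```scala
// Hedged sketch; the real session registry and holder types differ.
def getOrCreateIsolatedSession(
    userId: String,
    sessionId: String,
    previouslySeenSessionId: Option[String]): SessionHolder = {
  val holder = sessionStore.getOrCreate(userId, sessionId)
  previouslySeenSessionId.foreach { seen =>
    if (seen != holder.serverSessionId) {
      // Surfaces the SESSION_CHANGED error class added in this commit.
      throw new IllegalStateException(
        "The existing Spark server driver instance has restarted. Please reconnect.")
    }
  }
  holder
}
```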
@@ -32,9 +32,13 @@ class SparkConnectFetchErrorDetailsHandler(
responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]) {

def handle(v: proto.FetchErrorDetailsRequest): Unit = {
val previousSessionId = v.hasClientObservedServerSideSessionId match {
case true => Some(v.getClientObservedServerSideSessionId)
case false => None
}
val sessionHolder =
SparkConnectService
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId, previousSessionId)

val response = Option(sessionHolder.errorIdToError.getIfPresent(v.getErrorId))
.map { error =>