Skip to content

Commit a0050e0

Browse files
committed
[query] run ServiceBackend in Py4jQueryDriver
1 parent 6f58b61 commit a0050e0

File tree

12 files changed

+374
-66
lines changed

12 files changed

+374
-66
lines changed

batch/batch/front_end/front_end.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1842,11 +1842,15 @@ async def create_update(request, userdata):
18421842
update_id, start_job_group_id, start_job_id = await _create_batch_update(
18431843
batch_id, update_spec['token'], n_jobs, n_job_groups, user, db
18441844
)
1845-
return json_response({
1845+
1846+
response = {
18461847
'update_id': update_id,
18471848
'start_job_group_id': start_job_group_id,
18481849
'start_job_id': start_job_id,
1849-
})
1850+
}
1851+
1852+
log.info(response)
1853+
return json_response(response)
18501854

18511855

18521856
async def _create_batch_update(
@@ -1866,7 +1870,7 @@ async def update(tx: Transaction):
18661870
)
18671871

18681872
if record:
1869-
return (record['update_id'], record['start_job_id'], record['start_job_group_id'])
1873+
return (record['update_id'], record['start_job_group_id'], record['start_job_id'])
18701874

18711875
# We use FOR UPDATE so that we serialize batch update insertions
18721876
# This is necessary to reserve job id and job group id ranges.

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 69 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
package is.hail.backend.service
22

3+
import is.hail.HAIL_REVISION
34
import is.hail.backend._
45
import is.hail.backend.Backend.PartitionFn
56
import is.hail.backend.local.LocalTaskContext
@@ -12,6 +13,7 @@ import is.hail.expr.ir.analyses.SemanticHash
1213
import is.hail.expr.ir.lowering._
1314
import is.hail.services._
1415
import is.hail.services.JobGroupStates.{Cancelled, Failure, Success}
16+
import is.hail.services.oauth2.{CloudCredentials, HailCredentials}
1517
import is.hail.types._
1618
import is.hail.types.physical._
1719
import is.hail.utils._
@@ -26,8 +28,65 @@ import scala.util.control.NonFatal
2628
import java.io._
2729
import java.util.concurrent.Executors
2830

31+
import com.fasterxml.jackson.core.StreamReadConstraints
32+
2933
object ServiceBackend {
3034
val MaxAvailableGcsConnections = 1000
35+
36+
// See https://github.com/hail-is/hail/issues/14580
37+
StreamReadConstraints.overrideDefaultStreamReadConstraints(
38+
StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()
39+
)
40+
41+
def pyServiceBackend(
42+
name: String,
43+
batchId_ : Integer,
44+
billingProject: String,
45+
deployConfigFile: String,
46+
workerCores: String,
47+
workerMemory: String,
48+
storage: String,
49+
cloudfuse: Array[CloudfuseConfig],
50+
regions: Array[String],
51+
): ServiceBackend = {
52+
val credentials: CloudCredentials =
53+
HailCredentials().getOrElse(CloudCredentials(keyPath = None))
54+
55+
val client =
56+
BatchClient(
57+
DeployConfig.fromConfigFile(deployConfigFile),
58+
credentials,
59+
)
60+
61+
val batchId =
62+
Option(batchId_).map(_.toInt).getOrElse {
63+
client.newBatch(
64+
BatchRequest(
65+
billing_project = billingProject,
66+
token = tokenUrlSafe,
67+
n_jobs = 0,
68+
attributes = Map("name" -> name),
69+
)
70+
)
71+
}
72+
73+
val workerConfig =
74+
BatchJobConfig(
75+
workerCores,
76+
workerMemory,
77+
storage,
78+
cloudfuse,
79+
regions,
80+
)
81+
82+
new ServiceBackend(
83+
name,
84+
client,
85+
GitRevision(HAIL_REVISION),
86+
BatchConfig(batchId, 0),
87+
workerConfig,
88+
)
89+
}
3190
}
3291

3392
case class BatchJobConfig(
@@ -122,10 +181,16 @@ class ServiceBackend(
122181
)
123182

124183
stageCount += 1
125-
126-
Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
127-
val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
128-
(response, startJobId)
184+
try {
185+
Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
186+
val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
187+
(response, startJobId)
188+
} catch {
189+
case _: InterruptedException =>
190+
batchClient.cancelJobGroup(batchConfig.batchId, jobGroupId)
191+
Thread.currentThread().interrupt()
192+
throw new CancellationException()
193+
}
129194
}
130195

131196
override def mapCollectPartitions(

hail/hail/src/is/hail/services/BatchClient.scala

Lines changed: 28 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ import is.hail.services.BatchClient.{
77
}
88
import is.hail.services.JobGroupStates.isTerminal
99
import is.hail.services.oauth2.CloudCredentials
10-
import is.hail.services.requests.Requester
10+
import is.hail.services.requests.{ClientResponseException, Requester}
1111
import is.hail.utils._
1212

1313
import scala.collection.immutable.Stream.cons
@@ -270,22 +270,33 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
270270
batchId
271271
}
272272

273-
def newJobGroup(req: JobGroupRequest): (Int, Int) = {
274-
val nJobs = req.jobs.length
275-
val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs)
276-
log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")
277-
278-
createJobGroup(updateId, req)
279-
log.info(s"Created job group $startJobGroupId for batch ${req.batch_id}")
280-
281-
createJobsIncremental(req.batch_id, updateId, req.jobs)
282-
log.info(s"Submitted $nJobs in job group $startJobGroupId for batch ${req.batch_id}")
283-
284-
commitUpdate(req.batch_id, updateId)
285-
log.info(s"Committed update $updateId for batch ${req.batch_id}.")
286-
287-
(startJobGroupId, startJobId)
288-
}
273+
def newJobGroup(req: JobGroupRequest): (Int, Int) =
274+
retryable { attempts =>
275+
try {
276+
val nJobs = req.jobs.length
277+
val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs)
278+
log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")
279+
280+
createJobGroup(updateId, req)
281+
log.info(s"Created job group $startJobGroupId for batch ${req.batch_id}")
282+
283+
createJobsIncremental(req.batch_id, updateId, req.jobs)
284+
log.info(s"Submitted $nJobs in job group $startJobGroupId for batch ${req.batch_id}")
285+
286+
commitUpdate(req.batch_id, updateId)
287+
log.info(s"Committed update $updateId for batch ${req.batch_id}.")
288+
289+
(startJobGroupId, startJobId)
290+
} catch {
291+
case e: ClientResponseException
292+
if attempts < 5 &&
293+
e.status == 400 &&
294+
e.getMessage.contains("job group specs were not submitted in order") =>
295+
Thread.sleep(delayMsForTry(attempts))
296+
log.info("retry", e)
297+
retry
298+
}
299+
}
289300

290301
def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse =
291302
req

hail/hail/src/is/hail/services/package.scala

Lines changed: 16 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import is.hail.services.requests.ClientResponseException
44
import is.hail.shadedazure.com.azure.storage.common.implementation.Constants
55
import is.hail.utils._
66

7+
import scala.annotation.tailrec
78
import scala.util.Random
89

910
import java.io._
@@ -30,31 +31,24 @@ package object services {
3031

3132
private[this] val LOG_2_MAX_MULTIPLIER =
3233
30 // do not set larger than 30 due to integer overflow calculating multiplier
33-
private[this] val DEFAULT_MAX_DELAY_MS = 60000
34-
private[this] val DEFAULT_BASE_DELAY_MS = 1000
34+
private[this] val DEFAULT_MAX_DELAY_MS = 60000L
35+
private[this] val DEFAULT_BASE_DELAY_MS = 1000L
3536

3637
def delayMsForTry(
3738
tries: Int,
38-
baseDelayMs: Int = DEFAULT_BASE_DELAY_MS,
39-
maxDelayMs: Int = DEFAULT_MAX_DELAY_MS,
40-
): Int = {
39+
baseDelayMs: Long = DEFAULT_BASE_DELAY_MS,
40+
maxDelayMs: Long = DEFAULT_MAX_DELAY_MS,
41+
): Long = {
4142
// Based on AWS' recommendations:
4243
// - https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
4344
/* -
4445
* https://github.com/aws/aws-sdk-java/blob/master/aws-java-sdk-core/src/main/java/com/amazonaws/retry/PredefinedBackoffStrategies.java */
4546
val multiplier = 1L << math.min(tries, LOG_2_MAX_MULTIPLIER)
46-
val ceilingForDelayMs = math.min(baseDelayMs * multiplier, maxDelayMs.toLong).toInt
47-
val proposedDelayMs = ceilingForDelayMs / 2 + Random.nextInt(ceilingForDelayMs / 2 + 1)
48-
return proposedDelayMs
47+
val ceilingForDelayMs = math.min(baseDelayMs * multiplier, maxDelayMs).toInt
48+
ceilingForDelayMs / 2L + Random.nextInt(ceilingForDelayMs / 2 + 1)
4949
}
5050

51-
def sleepBeforTry(
52-
tries: Int,
53-
baseDelayMs: Int = DEFAULT_BASE_DELAY_MS,
54-
maxDelayMs: Int = DEFAULT_MAX_DELAY_MS,
55-
) =
56-
Thread.sleep(delayMsForTry(tries, baseDelayMs, maxDelayMs).toLong)
57-
51+
@tailrec
5852
def isLimitedRetriesError(_e: Throwable): Boolean = {
5953
// An exception is a "limited retries error" if a rare, known bug in a dependency or in a cloud
6054
// provider can manifest as this exception *and* that manifestation is indistinguishable from a
@@ -94,6 +88,7 @@ package object services {
9488
}
9589
}
9690

91+
@tailrec
9792
def isTransientError(_e: Throwable): Boolean = {
9893
// ReactiveException is package private inside reactor.core.Exceptions so we cannot access
9994
// it directly for an isInstance check. AFAICT, this is the only way to check if we received
@@ -185,14 +180,11 @@ package object services {
185180
}
186181
}
187182

188-
def retryTransientErrors[T](f: => T, reset: Option[() => Unit] = None): T = {
189-
var tries = 0
190-
while (true) {
191-
try
192-
return f
183+
def retryTransientErrors[T](f: => T, reset: Option[() => Unit] = None): T =
184+
retryable { tries =>
185+
try f
193186
catch {
194187
case e: Exception =>
195-
tries += 1
196188
val delay = delayMsForTry(tries)
197189
if (tries <= 5 && isLimitedRetriesError(e)) {
198190
log.warn(
@@ -205,14 +197,12 @@ package object services {
205197
} else if (tries % 10 == 0) {
206198
log.warn(s"Encountered $tries transient errors, most recent one was $e.")
207199
}
208-
Thread.sleep(delay.toLong)
200+
Thread.sleep(delay)
201+
reset.foreach(_())
202+
retry
209203
}
210-
reset.foreach(_())
211204
}
212205

213-
throw new AssertionError("unreachable")
214-
}
215-
216206
def formatException(e: Throwable): String = {
217207
using(new StringWriter()) { sw =>
218208
using(new PrintWriter(sw)) { pw =>

hail/hail/src/is/hail/utils/StringSocketAppender.scala

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,16 @@ object StringSocketAppender {
1313
// low reconnection delay because everything is local
1414
val DEFAULT_RECONNECTION_DELAY = 100
1515

16-
var theAppender: StringSocketAppender = _
16+
private var theAppender: StringSocketAppender = _
1717

18-
def get(): StringSocketAppender = theAppender
18+
def get(): StringSocketAppender =
19+
synchronized {
20+
if (theAppender == null) theAppender = new StringSocketAppender
21+
theAppender
22+
}
1923
}
2024

21-
class StringSocketAppender() extends AppenderSkeleton {
25+
class StringSocketAppender extends AppenderSkeleton {
2226
private var address: InetAddress = _
2327
private var port: Int = _
2428
private var os: OutputStream = _
@@ -39,7 +43,7 @@ class StringSocketAppender() extends AppenderSkeleton {
3943

4044
override def close(): Unit = {
4145
if (closed) return
42-
this.closed = true
46+
closed = true
4347
cleanUp()
4448
}
4549

@@ -58,6 +62,10 @@ class StringSocketAppender() extends AppenderSkeleton {
5862
connector.interrupted = true
5963
connector = null // allow gc
6064
}
65+
66+
StringSocketAppender.synchronized {
67+
StringSocketAppender.theAppender = null
68+
}
6169
}
6270

6371
private def connect(address: InetAddress, port: Int): Unit = {

hail/hail/src/is/hail/utils/package.scala

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ import scala.collection.{mutable, GenTraversableOnce, TraversableOnce}
99
import scala.collection.compat._
1010
import scala.collection.mutable.ArrayBuffer
1111
import scala.reflect.ClassTag
12-
import scala.util.control.NonFatal
12+
import scala.util.control.{ControlThrowable, NonFatal}
1313

1414
import java.io._
1515
import java.lang.reflect.Method
@@ -1038,6 +1038,23 @@ package object utils
10381038

10391039
def jsonToBytes(v: JValue): Array[Byte] =
10401040
JsonMethods.compact(v).getBytes(StandardCharsets.UTF_8)
1041+
1042+
private[this] object Retry extends ControlThrowable
1043+
1044+
def retry[A]: A = throw Retry
1045+
1046+
def retryable[A](f: Int => A): A = {
1047+
var attempts: Int = 0
1048+
1049+
while (true)
1050+
try return f(attempts)
1051+
catch {
1052+
case Retry =>
1053+
attempts += 1
1054+
}
1055+
1056+
uninitialized
1057+
}
10411058
}
10421059

10431060
class CancellingExecutorService(delegate: ExecutorService) extends AbstractExecutorService {

0 commit comments

Comments
 (0)