Skip to content

Commit 892768c

Browse files
committed
[query] run ServiceBackend in Py4jQueryDriver
1 parent 72e7d4b commit 892768c

File tree

10 files changed

+353
-59
lines changed

10 files changed

+353
-59
lines changed

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package is.hail.backend.service
22

3+
import is.hail.HAIL_REVISION
34
import is.hail.backend._
45
import is.hail.backend.Backend.PartitionFn
56
import is.hail.backend.local.LocalTaskContext
@@ -12,6 +13,7 @@ import is.hail.expr.ir.analyses.SemanticHash
1213
import is.hail.expr.ir.lowering._
1314
import is.hail.services._
1415
import is.hail.services.JobGroupStates.{Cancelled, Failure, Success}
16+
import is.hail.services.oauth2.{CloudCredentials, HailCredentials}
1517
import is.hail.types._
1618
import is.hail.types.physical._
1719
import is.hail.utils._
@@ -26,8 +28,65 @@ import scala.util.control.NonFatal
2628
import java.io._
2729
import java.util.concurrent.Executors
2830

31+
import com.fasterxml.jackson.core.StreamReadConstraints
32+
2933
/** Companion of [[ServiceBackend]]: shared constants, one-time JSON parser configuration and the
  * factory used to build a backend from plain (Python-friendly) arguments.
  */
object ServiceBackend {
  val MaxAvailableGcsConnections = 1000

  // See https://github.com/hail-is/hail/issues/14580
  // Runs once, when this object is first initialised: replaces Jackson's default stream-read
  // constraints so that a single JSON string may be up to Integer.MAX_VALUE characters long.
  StreamReadConstraints.overrideDefaultStreamReadConstraints(
    StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()
  )

  /** Builds a [[ServiceBackend]] together with the [[BatchClient]] it talks through.
    *
    * @param name
    *   name given to the backend; also stored as the `name` attribute of a newly created batch
    * @param batchId_
    *   id of an existing batch to attach to, or `null` to create a fresh batch
    * @param billingProject
    *   billing project used when a new batch has to be created
    * @param deployConfigFile
    *   handed to `DeployConfig.fromConfigFile` to configure the batch client
    * @param workerCores
    *   forwarded unchanged to [[BatchJobConfig]]
    * @param workerMemory
    *   forwarded unchanged to [[BatchJobConfig]]
    * @param storage
    *   forwarded unchanged to [[BatchJobConfig]]
    * @param cloudfuse
    *   forwarded unchanged to [[BatchJobConfig]]
    * @param regions
    *   forwarded unchanged to [[BatchJobConfig]]
    * @return
    *   a backend that holds the newly created batch client
    */
  def pyServiceBackend(
    name: String,
    batchId_ : Integer,
    billingProject: String,
    deployConfigFile: String,
    workerCores: String,
    workerMemory: String,
    storage: String,
    cloudfuse: Array[CloudfuseConfig],
    regions: Array[String],
  ): ServiceBackend = {
    // Prefer Hail credentials; otherwise fall back to cloud credentials without a key path.
    val credentials: CloudCredentials =
      HailCredentials().getOrElse(CloudCredentials(keyPath = None))

    val client =
      BatchClient(
        DeployConfig.fromConfigFile(deployConfigFile),
        credentials,
      )

    // From here on the client is open. If anything below fails, nobody else holds a reference
    // to it, so it must be closed here rather than leaked into a JVM that may be reused.
    try {
      val batchId =
        // `batchId_` is a boxed Integer so that callers can pass null for "no batch yet".
        Option(batchId_).map(_.toInt).getOrElse {
          client.newBatch(
            BatchRequest(
              billing_project = billingProject,
              token = tokenUrlSafe,
              n_jobs = 0,
              attributes = Map("name" -> name),
            )
          )
        }

      val workerConfig =
        BatchJobConfig(
          workerCores,
          workerMemory,
          storage,
          cloudfuse,
          regions,
        )

      new ServiceBackend(
        name,
        client,
        GitRevision(HAIL_REVISION),
        BatchConfig(batchId, 0),
        workerConfig,
      )
    } catch {
      case NonFatal(e) =>
        // Keep the original failure as the primary error; attach any close() failure to it.
        try client.close()
        catch { case NonFatal(onClose) => e.addSuppressed(onClose) }
        throw e
    }
  }
}
3291

3392
case class BatchJobConfig(
@@ -122,10 +181,16 @@ class ServiceBackend(
122181
)
123182

124183
stageCount += 1
125-
126-
Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
127-
val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
128-
(response, startJobId)
184+
try {
185+
Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
186+
val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
187+
(response, startJobId)
188+
} catch {
189+
case _: InterruptedException =>
190+
batchClient.cancelJobGroup(batchConfig.batchId, jobGroupId)
191+
Thread.currentThread().interrupt()
192+
throw new CancellationException()
193+
}
129194
}
130195

131196
override def mapCollectPartitions(

hail/hail/src/is/hail/services/BatchClient.scala

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import is.hail.services.BatchClient.{
77
}
88
import is.hail.services.JobGroupStates.isTerminal
99
import is.hail.services.oauth2.CloudCredentials
10-
import is.hail.services.requests.Requester
10+
import is.hail.services.requests.{ClientResponseException, Requester}
1111
import is.hail.utils._
1212

1313
import scala.collection.compat.immutable.LazyList
@@ -270,22 +270,33 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
270270
batchId
271271
}
272272

273-
def newJobGroup(req: JobGroupRequest): (Int, Int) = {
274-
val nJobs = req.jobs.length
275-
val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs)
276-
log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")
277-
278-
createJobGroup(updateId, req)
279-
log.info(s"Created job group $startJobGroupId for batch ${req.batch_id}")
280-
281-
createJobsIncremental(req.batch_id, updateId, req.jobs)
282-
log.info(s"Submitted $nJobs in job group $startJobGroupId for batch ${req.batch_id}")
283-
284-
commitUpdate(req.batch_id, updateId)
285-
log.info(s"Committed update $updateId for batch ${req.batch_id}.")
286-
287-
(startJobGroupId, startJobId)
288-
}
273+
/** Submits `req` as a new job group: begins an update, creates the group, uploads its jobs and
  * commits the update.
  *
  * If the service answers with HTTP 400 and the "job group specs were not submitted in order"
  * message, the whole sequence is retried after a randomised back-off sleep.
  *
  * @return
  *   the pair `(startJobGroupId, startJobId)` reported by `beginUpdate`
  */
def newJobGroup(req: JobGroupRequest): (Int, Int) = {
  val batchId = req.batch_id

  // One full pass through the update protocol.
  def submitOnce(): (Int, Int) = {
    val jobCount = req.jobs.length
    val (updateId, firstJobGroupId, firstJobId) = beginUpdate(batchId, req.token, jobCount)
    log.info(s"Began update '$updateId' for batch '$batchId'.")

    createJobGroup(updateId, req)
    log.info(s"Created job group $firstJobGroupId for batch $batchId")

    createJobsIncremental(batchId, updateId, req.jobs)
    log.info(s"Submitted $jobCount in job group $firstJobGroupId for batch $batchId")

    commitUpdate(batchId, updateId)
    log.info(s"Committed update $updateId for batch $batchId.")

    (firstJobGroupId, firstJobId)
  }

  // The only failure that triggers a retry.
  def submittedOutOfOrder(e: ClientResponseException): Boolean =
    e.status == 400 && e.getMessage.contains("job group specs were not submitted in order")

  retryable { attempts =>
    try submitOnce()
    catch {
      case e: ClientResponseException if submittedOutOfOrder(e) =>
        // Pick a random back-off step no larger than the number of attempts so far.
        val delay = delayMsForTry(Random.nextInt(attempts + 1))
        log.warn(f"Another process is updating batch $batchId - retrying in $delay ms.", e)
        Thread.sleep(delay)
        retry
    }
  }
}
289300

290301
def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse =
291302
req
@@ -338,7 +349,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
338349
Thread.sleep(d.toLong)
339350
}
340351

341-
throw new AssertionError("unreachable")
352+
unreachable
342353
}
343354

344355
override def close(): Unit =

hail/hail/src/is/hail/services/package.scala

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import is.hail.services.requests.ClientResponseException
44
import is.hail.shadedazure.com.azure.storage.common.implementation.Constants
55
import is.hail.utils._
66

7+
import scala.annotation.tailrec
78
import scala.util.Random
89

910
import java.io._
@@ -30,31 +31,24 @@ package object services {
3031

3132
private[this] val LOG_2_MAX_MULTIPLIER =
3233
30 // do not set larger than 30 due to integer overflow calculating multiplier
33-
private[this] val DEFAULT_MAX_DELAY_MS = 60000
34-
private[this] val DEFAULT_BASE_DELAY_MS = 1000
34+
private[this] val DEFAULT_MAX_DELAY_MS = 60000L
35+
private[this] val DEFAULT_BASE_DELAY_MS = 1000L
3536

3637
/** Computes a randomised exponential back-off delay, in milliseconds.
  *
  * The ceiling is `baseDelayMs * 2^tries`, with the exponent limited to `LOG_2_MAX_MULTIPLIER`
  * and the product limited to `maxDelayMs`. The result is drawn uniformly from
  * `[ceiling / 2, ceiling]`.
  *
  * All arithmetic stays in `Long`, so delays above `Int.MaxValue` are not truncated. A negative
  * `tries` is treated as 0, and a non-positive ceiling yields 0.
  *
  * @param tries
  *   number of attempts made so far
  * @param baseDelayMs
  *   ceiling used when `tries` is 0
  * @param maxDelayMs
  *   upper bound on the ceiling
  * @return
  *   a delay in milliseconds, never negative
  */
def delayMsForTry(
  tries: Int,
  baseDelayMs: Long = DEFAULT_BASE_DELAY_MS,
  maxDelayMs: Long = DEFAULT_MAX_DELAY_MS,
): Long = {
  // Based on AWS' recommendations:
  // - https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
  // - https://github.com/aws/aws-sdk-java/blob/master/aws-java-sdk-core/src/main/java/com/amazonaws/retry/PredefinedBackoffStrategies.java
  val exponent = math.max(0, math.min(tries, LOG_2_MAX_MULTIPLIER))
  val multiplier = 1L << exponent

  // Saturate instead of wrapping around when the product does not fit in a Long.
  val uncapped =
    if (baseDelayMs > Long.MaxValue / multiplier) Long.MaxValue
    else baseDelayMs * multiplier

  val ceilingForDelayMs = math.max(0L, math.min(uncapped, maxDelayMs))
  val half = ceilingForDelayMs / 2L

  // nextLong(bound) requires bound > 0; half + 1 is always at least 1.
  half + java.util.concurrent.ThreadLocalRandom.current().nextLong(half + 1L)
}
5050

51-
def sleepBeforTry(
52-
tries: Int,
53-
baseDelayMs: Int = DEFAULT_BASE_DELAY_MS,
54-
maxDelayMs: Int = DEFAULT_MAX_DELAY_MS,
55-
) =
56-
Thread.sleep(delayMsForTry(tries, baseDelayMs, maxDelayMs).toLong)
57-
51+
@tailrec
5852
def isLimitedRetriesError(_e: Throwable): Boolean = {
5953
// An exception is a "retry once error" if a rare, known bug in a dependency or in a cloud
6054
// provider can manifest as this exception *and* that manifestation is indistinguishable from a
@@ -94,6 +88,7 @@ package object services {
9488
}
9589
}
9690

91+
@tailrec
9792
def isTransientError(_e: Throwable): Boolean = {
9893
// ReactiveException is package private inside reactor.core.Exceptions so we cannot access
9994
// it directly for an isInstance check. AFAICT, this is the only way to check if we received
@@ -185,14 +180,11 @@ package object services {
185180
}
186181
}
187182

188-
def retryTransientErrors[T](f: => T, reset: Option[() => Unit] = None): T = {
189-
var tries = 0
190-
while (true) {
191-
try
192-
return f
183+
def retryTransientErrors[T](f: => T, reset: Option[() => Unit] = None): T =
184+
retryable { tries =>
185+
try f
193186
catch {
194187
case e: Exception =>
195-
tries += 1
196188
val delay = delayMsForTry(tries)
197189
if (tries <= 5 && isLimitedRetriesError(e)) {
198190
log.warn(
@@ -205,14 +197,12 @@ package object services {
205197
} else if (tries % 10 == 0) {
206198
log.warn(s"Encountered $tries transient errors, most recent one was $e.")
207199
}
208-
Thread.sleep(delay.toLong)
200+
Thread.sleep(delay)
201+
reset.foreach(_())
202+
retry
209203
}
210-
reset.foreach(_())
211204
}
212205

213-
throw new AssertionError("unreachable")
214-
}
215-
216206
def formatException(e: Throwable): String = {
217207
using(new StringWriter()) { sw =>
218208
using(new PrintWriter(sw)) { pw =>

hail/hail/src/is/hail/utils/package.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import scala.collection.compat._
99
import scala.collection.mutable
1010
import scala.collection.mutable.ArrayBuffer
1111
import scala.reflect.ClassTag
12-
import scala.util.control.NonFatal
12+
import scala.util.control.{ControlThrowable, NonFatal}
1313

1414
import java.io._
1515
import java.lang.reflect.Method
@@ -379,6 +379,8 @@ package object utils
379379

380380
// Placeholder for a value of type `T` that has not been assigned yet: `null` cast to `T`.
def uninitialized[T]: T = null.asInstanceOf[T]
381381

382+
// Marks a code path that must never execute: evaluating it throws AssertionError("unreachable").
// It is typed as an arbitrary `A` so it can be used wherever an expression of any type is needed.
def unreachable[A]: A = throw new AssertionError("unreachable")
383+
382384
// Single `MapAccumulate` instance, typed at `Nothing` for both type parameters.
// NOTE(review): presumably shared by `mapAccumulate` through a cast - confirm against its body.
private object mapAccumulateInstance extends MapAccumulate[Nothing, Nothing]
383385

384386
def mapAccumulate[C[_], U] =
@@ -958,6 +960,23 @@ package object utils
958960

959961
// Renders `v` as compact JSON text and returns that text encoded as UTF-8 bytes.
def jsonToBytes(v: JValue): Array[Byte] = {
  val rendered = JsonMethods.compact(v)
  rendered.getBytes(StandardCharsets.UTF_8)
}
963+
964+
// Singleton signal thrown by `retry` and caught by `retryable` to request another attempt.
// Being a ControlThrowable, it is not matched by `NonFatal(...)` handlers.
private[this] object Retry extends ControlThrowable
965+
966+
// Abandons the current attempt by throwing the `Retry` signal, which `retryable` catches in
// order to run its function again. Typed as an arbitrary `A` so it fits any expression position.
def retry[A]: A = throw Retry
967+
968+
/** Runs `f` until it produces a value, re-running it every time it signals `retry`.
  *
  * `f` receives the number of attempts that have already been retried: 0 on the first call and
  * one more after each `retry`. Any other exception thrown by `f` propagates to the caller
  * unchanged.
  */
def retryable[A](f: Int => A): A = {
  // The recursive call is outside the try/catch, so it stays in tail position.
  @scala.annotation.tailrec
  def attempt(retriesSoFar: Int): A = {
    val outcome: Option[A] =
      try Some(f(retriesSoFar))
      catch { case Retry => None }

    outcome match {
      case Some(value) => value
      case None => attempt(retriesSoFar + 1)
    }
  }

  attempt(0)
}
961980
}
962981

963982
class CancellingExecutorService(delegate: ExecutorService) extends AbstractExecutorService {

hail/python/hail/context.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,26 +366,47 @@ def init(
366366
backend = 'batch'
367367

368368
if backend == 'batch':
369-
return hail_event_loop().run_until_complete(
370-
init_batch(
369+
if os.getenv('HAIL_QUERY_USE_LOCAL_DRIVER') is not None:
370+
return hail.experimental.init(
371+
backend=backend,
372+
app_name=app_name,
371373
log=log,
372374
quiet=quiet,
373375
append=append,
374-
tmpdir=tmp_dir,
376+
tmp_dir=tmp_dir,
375377
default_reference=default_reference,
376378
global_seed=global_seed,
377379
driver_cores=driver_cores,
378380
driver_memory=driver_memory,
379381
worker_cores=worker_cores,
380382
worker_memory=worker_memory,
381383
batch_id=batch_id,
382-
name_prefix=app_name,
383384
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
384385
regions=regions,
385386
gcs_bucket_allow_list=gcs_bucket_allow_list,
386387
branching_factor=branching_factor,
387388
)
388-
)
389+
else:
390+
return hail_event_loop().run_until_complete(
391+
init_batch(
392+
log=log,
393+
quiet=quiet,
394+
append=append,
395+
tmpdir=tmp_dir,
396+
default_reference=default_reference,
397+
global_seed=global_seed,
398+
driver_cores=driver_cores,
399+
driver_memory=driver_memory,
400+
worker_cores=worker_cores,
401+
worker_memory=worker_memory,
402+
batch_id=batch_id,
403+
name_prefix=app_name,
404+
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
405+
regions=regions,
406+
gcs_bucket_allow_list=gcs_bucket_allow_list,
407+
branching_factor=branching_factor,
408+
)
409+
)
389410
if backend == 'spark':
390411
return init_spark(
391412
sc=sc,

hail/python/hail/experimental/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .context import init
12
from .datasets import load_dataset
23
from .db import DB
34
from .export_entries_by_col import export_entries_by_col
@@ -35,6 +36,7 @@
3536
'hail_metadata',
3637
'haplotype_freq_em',
3738
'import_gtf',
39+
'init',
3840
'ld_score',
3941
'ld_score_regression',
4042
'load_dataset',

0 commit comments

Comments
 (0)