Skip to content

Commit c5b8c2d

Browse files
committed
[query] Remove BlockMatrix persist from Backend interface
1 parent 901fafe commit c5b8c2d

File tree

12 files changed

+93
-94
lines changed

12 files changed

+93
-94
lines changed

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ import is.hail.io.{BufferSpec, TypedCodecSpec}
1212
import is.hail.io.fs._
1313
import is.hail.io.plink.LoadPlink
1414
import is.hail.io.vcf.LoadVCF
15-
import is.hail.linalg.BlockMatrix
1615
import is.hail.types._
1716
import is.hail.types.encoded.EType
1817
import is.hail.types.physical.PTuple
19-
import is.hail.types.virtual.{BlockMatrixType, TFloat64}
18+
import is.hail.types.virtual.TFloat64
2019
import is.hail.utils._
2120
import is.hail.variant.ReferenceGenome
2221

@@ -77,15 +76,6 @@ abstract class Backend extends Closeable {
7776

7877
def broadcast[T: ClassTag](value: T): BroadcastValue[T]
7978

80-
def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
81-
: Unit
82-
83-
def unpersist(backendContext: BackendContext, id: String): Unit
84-
85-
def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix
86-
87-
def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType
88-
8979
def parallelizeAndComputeWithIndex(
9080
backendContext: BackendContext,
9181
fs: FS,

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import is.hail.asm4s.HailClassLoader
66
import is.hail.backend.local.LocalTaskContext
77
import is.hail.expr.ir.lowering.IrMetadata
88
import is.hail.io.fs.FS
9+
import is.hail.linalg.BlockMatrix
910
import is.hail.utils._
1011
import is.hail.variant.ReferenceGenome
1112

@@ -71,6 +72,7 @@ object ExecuteContext {
7172
flags: HailFeatureFlags,
7273
backendContext: BackendContext,
7374
irMetadata: IrMetadata,
75+
blockMatrixCache: mutable.Map[String, BlockMatrix],
7476
)(
7577
f: ExecuteContext => T
7678
): T = {
@@ -89,6 +91,7 @@ object ExecuteContext {
8991
flags,
9092
backendContext,
9193
irMetadata,
94+
blockMatrixCache,
9295
))(f(_))
9396
}
9497
}
@@ -118,6 +121,7 @@ class ExecuteContext(
118121
val flags: HailFeatureFlags,
119122
val backendContext: BackendContext,
120123
val irMetadata: IrMetadata,
124+
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
121125
) extends Closeable {
122126

123127
val rngNonce: Long =
@@ -184,6 +188,7 @@ class ExecuteContext(
184188
flags: HailFeatureFlags = this.flags,
185189
backendContext: BackendContext = this.backendContext,
186190
irMetadata: IrMetadata = this.irMetadata,
191+
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
187192
)(
188193
f: ExecuteContext => A
189194
): A =
@@ -200,5 +205,6 @@ class ExecuteContext(
200205
flags,
201206
backendContext,
202207
irMetadata,
208+
blockMatrixCache,
203209
))(f)
204210
}
hail/src/main/scala/is/hail/backend/caching/BlockMatrixCache.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package is.hail.backend.caching
2+
3+
import is.hail.linalg.BlockMatrix
4+
5+
import scala.collection.mutable
6+
7+
/** A mutable map from string identifiers to [[BlockMatrix]] values that calls `unpersist()` on a
  * matrix when it is removed, when the cache is cleared, or when the cache is closed.
  *
  * Entries are held in a `mutable.LinkedHashMap`, so iteration follows insertion order.
  *
  * NOTE(review): `+=` simply stores the pair; if the key is already present, the previously stored
  * matrix is replaced without `unpersist()` being called on it. Confirm that this is intended.
  */
class BlockMatrixCache extends mutable.AbstractMap[String, BlockMatrix] with AutoCloseable {

  // Backing store; every map operation below delegates to it.
  private[this] val entries: mutable.Map[String, BlockMatrix] =
    mutable.LinkedHashMap.empty[String, BlockMatrix]

  /** Stores `kv` in the cache and returns this cache. */
  override def +=(kv: (String, BlockMatrix)): BlockMatrixCache.this.type = {
    entries += kv
    this
  }

  /** Unpersists and then drops the matrix stored under `key`; does nothing if the key is absent. */
  override def -=(key: String): BlockMatrixCache.this.type = {
    get(key) match {
      case Some(cached) =>
        // Unpersist before removing, so the entry is still present if `unpersist()` throws.
        cached.unpersist()
        entries -= key
      case None =>
        ()
    }
    this
  }

  /** Looks up the matrix stored under `key`, if any. */
  override def get(key: String): Option[BlockMatrix] =
    entries.get(key)

  /** Iterates over the `(id, matrix)` pairs of the backing map. */
  override def iterator: Iterator[(String, BlockMatrix)] =
    entries.iterator

  /** Unpersists every stored matrix, then empties the cache. */
  override def clear(): Unit = {
    for (cached <- entries.values) cached.unpersist()
    entries.clear()
  }

  /** Closing the cache is the same as clearing it. */
  override def close(): Unit = clear()
}

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@ import is.hail.expr.ir._
1010
import is.hail.expr.ir.analyses.SemanticHash
1111
import is.hail.expr.ir.lowering._
1212
import is.hail.io.fs._
13-
import is.hail.linalg.BlockMatrix
1413
import is.hail.types._
1514
import is.hail.types.physical.PTuple
1615
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
17-
import is.hail.types.virtual.{BlockMatrixType, TVoid}
16+
import is.hail.types.virtual.TVoid
1817
import is.hail.utils._
1918
import is.hail.variant.ReferenceGenome
2019

@@ -113,6 +112,7 @@ class LocalBackend(
113112
ExecutionCache.fromFlags(flags, fs, tmpdir)
114113
},
115114
new IrMetadata(),
115+
ImmutableMap.empty,
116116
)(f)
117117
}
118118

@@ -215,15 +215,6 @@ class LocalBackend(
215215
): TableReader =
216216
LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt, nPartitions)
217217

218-
def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
219-
: Unit = ???
220-
221-
def unpersist(backendContext: BackendContext, id: String): Unit = ???
222-
223-
def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ???
224-
225-
def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ???
226-
227218
def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
228219
: TableStage =
229220
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import is.hail.expr.ir.functions.IRFunctionRegistry
1414
import is.hail.expr.ir.lowering._
1515
import is.hail.io.fs._
1616
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
17-
import is.hail.linalg.BlockMatrix
1817
import is.hail.services.{BatchClient, JobGroupRequest, _}
1918
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
2019
import is.hail.types._
@@ -375,15 +374,6 @@ class ServiceBackend(
375374
): TableReader =
376375
LowerDistributedSort.distributedSort(ctx, inputStage, sortFields, rt, nPartitions)
377376

378-
def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
379-
: Unit = ???
380-
381-
def unpersist(backendContext: BackendContext, id: String): Unit = ???
382-
383-
def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ???
384-
385-
def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ???
386-
387377
def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
388378
: TableStage =
389379
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)
@@ -402,6 +392,7 @@ class ServiceBackend(
402392
flags,
403393
serviceBackendContext,
404394
new IrMetadata(),
395+
ImmutableMap.empty,
405396
)(f)
406397
}
407398

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ import is.hail.{HailContext, HailFeatureFlags}
44
import is.hail.annotations._
55
import is.hail.asm4s._
66
import is.hail.backend._
7+
import is.hail.backend.caching.BlockMatrixCache
78
import is.hail.backend.py4j.Py4JBackendExtensions
89
import is.hail.expr.Validate
910
import is.hail.expr.ir._
1011
import is.hail.expr.ir.analyses.SemanticHash
1112
import is.hail.expr.ir.lowering._
1213
import is.hail.io.{BufferSpec, TypedCodecSpec}
1314
import is.hail.io.fs._
14-
import is.hail.linalg.BlockMatrix
1515
import is.hail.rvd.RVD
1616
import is.hail.types._
1717
import is.hail.types.physical.{PStruct, PTuple}
@@ -351,20 +351,8 @@ class SparkBackend(
351351
override val longLifeTempFileManager: TempFileManager =
352352
new OwningTempFileManager(fs)
353353

354-
val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache()
355-
356-
def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
357-
: Unit = bmCache.persistBlockMatrix(id, value, storageLevel)
358-
359-
def unpersist(backendContext: BackendContext, id: String): Unit = unpersist(id)
360-
361-
def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix =
362-
bmCache.getPersistedBlockMatrix(id)
363-
364-
def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType =
365-
bmCache.getPersistedBlockMatrixType(id)
366-
367-
def unpersist(id: String): Unit = bmCache.unpersistBlockMatrix(id)
354+
private[this] val bmCache: BlockMatrixCache =
355+
new BlockMatrixCache()
368356

369357
def createExecuteContextForTests(
370358
timer: ExecutionTimer,
@@ -387,6 +375,7 @@ class SparkBackend(
387375
ExecutionCache.forTesting
388376
},
389377
new IrMetadata(),
378+
ImmutableMap.empty,
390379
)
391380

392381
override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -406,6 +395,7 @@ class SparkBackend(
406395
ExecutionCache.fromFlags(flags, fs, tmpdir)
407396
},
408397
new IrMetadata(),
398+
bmCache,
409399
)(f)
410400
}
411401

@@ -470,6 +460,7 @@ class SparkBackend(
470460
override def asSpark(op: String): SparkBackend = this
471461

472462
def close(): Unit = {
463+
bmCache.close()
473464
SparkBackend.stop()
474465
longLifeTempFileManager.close()
475466
}

hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala

Lines changed: 0 additions & 25 deletions
This file was deleted.

hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package is.hail.expr.ir
22

3-
import is.hail.HailContext
43
import is.hail.annotations.NDArray
5-
import is.hail.backend.{BackendContext, ExecuteContext}
4+
import is.hail.backend.ExecuteContext
65
import is.hail.expr.Nat
76
import is.hail.expr.ir.lowering.{BMSContexts, BlockMatrixStage2, LowererUnsupportedOperation}
87
import is.hail.io.{StreamBufferSpec, TypedCodecSpec}
@@ -106,7 +105,7 @@ object BlockMatrixReader {
106105
def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixReader =
107106
(jv \ "name").extract[String] match {
108107
case "BlockMatrixNativeReader" => BlockMatrixNativeReader.fromJValue(ctx.fs, jv)
109-
case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx.backendContext, jv)
108+
case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx, jv)
110109
case _ => jv.extract[BlockMatrixReader]
111110
}
112111
}
@@ -274,22 +273,20 @@ case class BlockMatrixBinaryReader(path: String, shape: IndexedSeq[Long], blockS
274273
case class BlockMatrixNativePersistParameters(id: String)
275274

276275
object BlockMatrixPersistReader {
277-
def fromJValue(ctx: BackendContext, jv: JValue): BlockMatrixPersistReader = {
276+
def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixPersistReader = {
278277
implicit val formats: Formats = BlockMatrixReader.formats
279278
val params = jv.extract[BlockMatrixNativePersistParameters]
280279
BlockMatrixPersistReader(
281280
params.id,
282-
HailContext.backend.getPersistedBlockMatrixType(ctx, params.id),
281+
BlockMatrixType.fromBlockMatrix(ctx.BlockMatrixCache(params.id)),
283282
)
284283
}
285284
}
286285

287286
case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends BlockMatrixReader {
288287
def pathsUsed: Seq[String] = FastSeq()
289288
lazy val fullType: BlockMatrixType = typ
290-
291-
def apply(ctx: ExecuteContext): BlockMatrix =
292-
HailContext.backend.getPersistedBlockMatrix(ctx.backendContext, id)
289+
def apply(ctx: ExecuteContext): BlockMatrix = ctx.BlockMatrixCache(id)
293290
}
294291

295292
case class BlockMatrixMap(child: BlockMatrixIR, eltName: Name, f: IR, needsDense: Boolean)

hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package is.hail.expr.ir
22

3-
import is.hail.HailContext
43
import is.hail.annotations.Region
54
import is.hail.asm4s._
65
import is.hail.backend.ExecuteContext
@@ -190,7 +189,7 @@ case class BlockMatrixPersistWriter(id: String, storageLevel: String) extends Bl
190189
def pathOpt: Option[String] = None
191190

192191
def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit =
193-
HailContext.backend.persist(ctx.backendContext, id, bm, storageLevel)
192+
ctx.BlockMatrixCache += id -> bm.persist(storageLevel)
194193

195194
def loweredTyp: Type = TVoid
196195
}

hail/src/main/scala/is/hail/utils/ErrorHandling.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class HailWorkerException(
1818
trait ErrorHandling {
1919
def fatal(msg: String): Nothing = throw new HailException(msg)
2020

21-
def fatal(msg: String, errorId: Int) = throw new HailException(msg, errorId)
21+
def fatal(msg: String, errorId: Int): Nothing = throw new HailException(msg, errorId)
2222

2323
def fatal(msg: String, cause: Throwable): Nothing = throw new HailException(msg, None, cause)
2424

0 commit comments

Comments
 (0)