Skip to content

Commit ae2bc2a

Browse files
Refactor BlockMatrixSparsity
1 parent 052c838 commit ae2bc2a

File tree

18 files changed

+1107
-923
lines changed

18 files changed

+1107
-923
lines changed

hail/hail/src-2.12/is/hail/utils/compat/immutable/ArraySeq.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,7 @@ object ArraySeq {
2323

2424
/** Implicit `CanBuildFrom` for building an `ArraySeq[T]`; the instance is the one supplied by
  * `A.canBuildFrom`. A `ClassTag[T]` must be available at the call site.
  */
implicit def canBuildFrom[T: ClassTag]: CanBuildFrom[ArraySeq[_], T, ArraySeq[T]] = {
  A.canBuildFrom
}
26+
27+
/** Builds an `ArraySeq[T]` whose element at position `i` is `f(i)`, for `n` positions.
  *
  * The values are first materialised with `Array.tabulate` and the resulting array is then wrapped
  * via `A.unsafeWrapArray`.
  *
  * @param n number of elements to generate
  * @param f function from position to element
  */
def tabulate[T: ClassTag](n: Int)(f: Int => T): ArraySeq[T] = {
  val backing: Array[T] = Array.tabulate(n)(f)
  A.unsafeWrapArray(backing)
}
2629
}

hail/hail/src-2.12/is/hail/utils/compat/package.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,20 @@ package is.hail.utils
33
import is.hail.utils.compat.immutable.ArraySeq
44

55
import scala.collection.compat.Factory
6+
import scala.collection.mutable
67
import scala.reflect.ClassTag
78

89
package object compat {
910
/** Implicit conversion from the `ArraySeq` companion object to a `Factory[A, ArraySeq[A]]`,
  * implemented by returning the companion's `canBuildFrom[A]`.
  */
implicit def arraySeqbf[A: ClassTag](ob: ArraySeq.type): Factory[A, ArraySeq[A]] = {
  ob.canBuildFrom[A]
}
12+
13+
/** Extension methods that give plain arrays `sortInPlace` / `sortInPlaceBy` operations.
  *
  * Both methods mutate the receiver array and return a `mutable.WrappedArray` view over that same
  * array (no copy of the elements is made for the result).
  */
implicit class ArrayOps[A](private val a: Array[A]) extends AnyVal {

  /** Sorts the underlying array in place with `scala.util.Sorting.stableSort`, using the implicit
    * ordering on `B`.
    *
    * @return the same array, wrapped as a `mutable.WrappedArray`
    */
  def sortInPlace[B >: A]()(implicit ct: ClassTag[B], ord: Ordering[B]): mutable.WrappedArray[A] = {
    // The array is viewed at the wider element type so the ordering on `B` can be applied.
    val widened = a.asInstanceOf[Array[B]]
    scala.util.Sorting.stableSort(widened)
    mutable.WrappedArray.make[A](a)
  }

  /** Sorts the underlying array in place, comparing elements by the key computed with `f`.
    *
    * @param f key extraction function
    * @return the same array, wrapped as a `mutable.WrappedArray`
    */
  def sortInPlaceBy[B](f: A => B)(implicit ord: Ordering[B], ct: ClassTag[A]): mutable.WrappedArray[A] = {
    val byKey: Ordering[A] = ord.on(f)
    sortInPlace[A]()(ct, byKey)
  }
}
1122
}

hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala

Lines changed: 79 additions & 80 deletions
Large diffs are not rendered by default.

hail/hail/src/is/hail/expr/ir/BlockMatrixWriter.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import is.hail.expr.ir.defs.{MetadataWriter, Str, UUID4, WriteMetadata, WriteVal
88
import is.hail.expr.ir.lowering.{BlockMatrixStage2, LowererUnsupportedOperation}
99
import is.hail.io.{StreamBufferSpec, TypedCodecSpec}
1010
import is.hail.io.fs.FS
11-
import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata}
11+
import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata, MatrixSparsity}
1212
import is.hail.types.TypeWithRequiredness
1313
import is.hail.types.encoded.{EBlockMatrixNDArray, ENumpyBinaryNDArray, EType}
1414
import is.hail.types.virtual._
@@ -135,8 +135,11 @@ case class BlockMatrixNativeMetadataWriter(
135135
cb: EmitCodeBuilder,
136136
region: Value[Region],
137137
): Unit = {
138-
val metaHelper =
139-
BMMetadataHelper(path, typ.blockSize, typ.nRows, typ.nCols, typ.linearizedDefinedBlocks)
138+
val partIdxToBlockIdx = typ.sparsity match {
139+
case _: MatrixSparsity.Dense => None
140+
case x: MatrixSparsity.Sparse => Some(x.definedBlocksColMajorLinear)
141+
}
142+
val metaHelper = BMMetadataHelper(path, typ.blockSize, typ.nRows, typ.nCols, partIdxToBlockIdx)
140143

141144
val pc = writeAnnotations.getOrFatal(cb, "write annotations can't be missing!").asIndexable
142145
val partFiles = cb.newLocal[Array[String]]("partFiles")

hail/hail/src/is/hail/expr/ir/IR.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,8 @@ package defs {
512512
def <=(other: IR): IR = ApplyComparisonOp(LTEQ, self, other)
513513

514514
def >=(other: IR): IR = ApplyComparisonOp(GTEQ, self, other)
515+
516+
/** Delegates to `logIR`, passing this IR (`self`) together with `messages` expanded as varargs. */
def log(messages: AnyRef*): IR = {
  logIR(self, messages: _*)
}
515517
}
516518

517519
object ErrorIDs {

hail/hail/src/is/hail/expr/ir/MatrixWriter.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import is.hail.io.gen.{BgenWriter, ExportGen}
1414
import is.hail.io.index.StagedIndexWriter
1515
import is.hail.io.plink.{BitPacker, ExportPlink}
1616
import is.hail.io.vcf.{ExportVCF, TabixVCF}
17-
import is.hail.linalg.BlockMatrix
17+
import is.hail.linalg.{BlockMatrix, MatrixSparsity}
1818
import is.hail.rvd.{IndexSpec, RVDPartitioner, RVDSpecMaker}
1919
import is.hail.types._
2020
import is.hail.types.encoded.{EBaseStruct, EBlockMatrixNDArray, EType}
@@ -2340,7 +2340,7 @@ case class MatrixBlockMatrixWriter(
23402340

23412341
val countColumnsIR = ArrayLen(GetField(ts.getGlobals(), colsFieldName))
23422342
val numCols: Int = CompileAndEvaluate[Int](ctx, countColumnsIR)
2343-
val numBlockCols: Int = (numCols - 1) / blockSize + 1
2343+
val numBlockCols: Int = BlockMatrixType.numBlocks(numCols.toLong, blockSize)
23442344
val lastBlockNumCols = (numCols - 1) % blockSize + 1
23452345

23462346
val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(paritionIR =>
@@ -2353,7 +2353,7 @@ case class MatrixBlockMatrixWriter(
23532353
val inputPartStops = inputPartStartsPlusLast.tail
23542354

23552355
val numRows = inputPartStartsPlusLast.last
2356-
val numBlockRows: Int = (numRows.toInt - 1) / blockSize + 1
2356+
val numBlockRows: Int = BlockMatrixType.numBlocks(numRows, blockSize)
23572357

23582358
// Zip contexts with partition starts and ends
23592359
val zippedWithStarts = ts.mapContexts { oldContextsStream =>
@@ -2510,7 +2510,7 @@ case class MatrixBlockMatrixWriter(
25102510
numRows,
25112511
numCols.toLong,
25122512
blockSize,
2513-
BlockMatrixSparsity.dense,
2513+
MatrixSparsity.dense(numBlockRows, numBlockCols),
25142514
)
25152515
RelationalWriter.scoped(path, overwrite, None)(WriteMetadata(
25162516
flatPaths,

hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,48 @@ object ArrayFunctions extends RegistryFunctions {
311311
ToArray(flatMapIR(ToStream(a))(ToStream(_)))
312312
}
313313

314+
// scatter(values, indices, len): builds an array of length `len` by pushing `values[pos]` at
// position `indices[pos]` for every defined entry of `indices`.
//
// Fails at runtime when `values` and `indices` differ in length, or when a defined index lies
// outside [0, len).
registerSCode3t(
  "scatter",
  Array(tv("T")),
  TArray(tv("T")),
  TArray(TInt32),
  TInt32,
  TArray(tv("T")),
  (_, a, _, _) => PCanonicalArray(a.asInstanceOf[SContainer].elementType.storageType()).sType,
) {
  case (
        er,
        cb,
        _,
        rt: SIndexablePointer,
        elts: SIndexableValue,
        indices: SIndexableValue,
        len: SInt32Value,
        errorID,
      ) =>
    val pt = rt.pType.asInstanceOf[PCanonicalArray]
    val (push, finish) =
      pt.constructFromIndicesUnsafe(cb, er.region, len.value, deepCopy = false)
    cb.if_(
      elts.loadLength.cne(indices.loadLength),
      cb._fatalWithError(errorID, "scatter: values and indices arrays have different lengths"),
    )
    indices.forEachDefined(cb) { case (cb, pos, idx: SInt32Value) =>
      // The guard rejects negative indices as well as indices >= len, so the message describes
      // an out-of-bounds index rather than claiming it is "greater than" the length.
      cb.if_(
        idx.value < 0 || idx.value >= len.value,
        cb._fatalWithError(
          errorID,
          "scatter: indices array contained index ",
          idx.value.toS,
          ", which is out of bounds for result length ",
          len.value.toS,
        ),
      )
      push(cb, idx.value, elts.loadElement(cb, pos))
    }
    finish(cb)
}
355+
314356
registerSCode4(
315357
"lowerBound",
316358
TArray(tv("T")),

hail/hail/src/is/hail/expr/ir/functions/Functions.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,31 @@ abstract class RegistryFunctions {
816816
case (r, cb, _, rt, Array(a1, a2, a3), errorID) => impl(r, cb, rt, a1, a2, a3, errorID)
817817
}
818818

819+
/** Registers a three-argument staged function that also declares type parameters.
  *
  * Forwards to `registerSCode` with the three argument types packed into an array and the
  * return-type function adapted through `unwrappedApply`. The registered implementation unpacks
  * the three-element argument array and calls `impl` with the individual values.
  *
  * @param name function name to register
  * @param typeParams type parameters passed through to `registerSCode`
  * @param mt1 type of the first argument
  * @param mt2 type of the second argument
  * @param mt3 type of the third argument
  * @param rt return type
  * @param pt computes the result `SType` from the return type and the three argument `SType`s
  * @param impl implementation, called with the emit region, code builder, type parameters,
  *   result `SType`, the three argument values and the error id
  */
def registerSCode3t(
  name: String,
  typeParams: Array[Type],
  mt1: Type,
  mt2: Type,
  mt3: Type,
  rt: Type,
  pt: (Type, SType, SType, SType) => SType,
)(
  impl: (
    EmitRegion,
    EmitCodeBuilder,
    Seq[Type],
    SType,
    SValue,
    SValue,
    SValue,
    Value[Int],
  ) => SValue
): Unit = {
  val argTypes = Array(mt1, mt2, mt3)
  registerSCode(name, argTypes, rt, unwrappedApply(pt), typeParams) {
    // Pattern variables are named distinctly from the method parameters to avoid shadowing.
    case (region, builder, typeArgs, resultSType, Array(first, second, third), errId) =>
      impl(region, builder, typeArgs, resultSType, first, second, third, errId)
  }
}
843+
819844
def registerSCode4(
820845
name: String,
821846
mt1: Type,

0 commit comments

Comments
 (0)