From 1b6a93c9eaadb49083b23658eea2024e67db4278 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 21 Jul 2025 15:27:47 -0400 Subject: [PATCH] [query] fix TableValue.mapRows ignored branches --- hail/hail/src/is/hail/HailFeatureFlags.scala | 1 - .../hail/src/is/hail/expr/ir/TableValue.scala | 374 +----------------- .../expr/ir/lowering/ExecuteRelational.scala | 5 +- hail/python/hail/backend/backend.py | 1 - hail/python/test/hail/expr/test_expr.py | 3 +- 5 files changed, 4 insertions(+), 380 deletions(-) diff --git a/hail/hail/src/is/hail/HailFeatureFlags.scala b/hail/hail/src/is/hail/HailFeatureFlags.scala index 49eff3139ec..18ecb1bf4d7 100644 --- a/hail/hail/src/is/hail/HailFeatureFlags.scala +++ b/hail/hail/src/is/hail/HailFeatureFlags.scala @@ -16,7 +16,6 @@ object HailFeatureFlags { // // The default values and envvars here are only used in the Scala tests. In all other // conditions, Python initializes the flags, see HailContext._initialize_flags in context.py. - ("distributed_scan_comb_op", ("HAIL_DEV_DISTRIBUTED_SCAN_COMB_OP" -> null)), ("grouped_aggregate_buffer_size", ("HAIL_GROUPED_AGGREGATE_BUFFER_SIZE" -> "50")), ("index_branching_factor", "HAIL_INDEX_BRANCHING_FACTOR" -> null), ("jvm_bytecode_dump", ("HAIL_DEV_JVM_BYTECODE_DUMP" -> null)), diff --git a/hail/hail/src/is/hail/expr/ir/TableValue.scala b/hail/hail/src/is/hail/expr/ir/TableValue.scala index 89b7094e62a..5a458c30abc 100644 --- a/hail/hail/src/is/hail/expr/ir/TableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/TableValue.scala @@ -4,7 +4,7 @@ import is.hail.HailContext import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.backend.spark.{SparkBackend, SparkTaskContext} +import is.hail.backend.spark.SparkTaskContext import is.hail.expr.TableAnnotationImpex import is.hail.expr.ir.agg.Aggs import is.hail.expr.ir.compile.{Compile, CompileWithAggregators} @@ -27,9 +27,6 @@ import is.hail.utils._ import scala.reflect.ClassTag -import java.io.{DataInputStream, DataOutputStream} - -import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.StructType @@ -1053,375 +1050,6 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow ) } - def mapRows(extracted: Aggs): TableValue = { - val fsBc = ctx.fsBc - val newType = typ.copy(rowType = extracted.postAggIR.typ.asInstanceOf[TStruct]) - - if (extracted.aggs.isEmpty) { - val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - Compile[AsmFunction3RegionLongLongLong]( - ctx, - FastSeq( - ( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), - ), - ( - TableIR.rowName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)), - ), - ), - FastSeq(classInfo[Region], LongInfo, LongInfo), - LongInfo, - Coalesce(FastSeq( - extracted.postAggIR, - Die("Internal error: TableMapRows: row expression missing", extracted.postAggIR.typ), - )), - ) - - val rowIterationNeedsGlobals = Mentions(extracted.postAggIR, TableIR.globalName) - val globalsBc = - if (rowIterationNeedsGlobals) - globals.broadcast(ctx.theHailClassLoader) - else - null - - val fsBc = ctx.fsBc - val itF = { (i: Int, ctx: RVDContext, it: Iterator[Long]) => - val globalRegion = ctx.partitionRegion - val globals = if (rowIterationNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) - else - 0 - - val newRow = - f(theHailClassLoaderForSparkWorkers, fsBc.value, 
SparkTaskContext.get(), globalRegion) - it.map(ptr => newRow(ctx.r, globals, ptr)) - } - - copy( - typ = newType, - rvd = rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF), - ) - } - - val scanInitNeedsGlobals = Mentions(extracted.init, TableIR.globalName) - val scanSeqNeedsGlobals = Mentions(extracted.seqPerElt, TableIR.globalName) - val rowIterationNeedsGlobals = Mentions(extracted.postAggIR, TableIR.globalName) - - val globalsBc = - if (rowIterationNeedsGlobals || scanInitNeedsGlobals || scanSeqNeedsGlobals) - globals.broadcast(ctx.theHailClassLoader) - else - null - - val spec = BufferSpec.blockedUncompressed - - // Order of operations: - // 1. init op on all aggs and serialize to byte array. - // 2. load in init op on each partition, seq op over partition, serialize. - // 3. load in partition aggregations, comb op as necessary, serialize. - // 4. load in partStarts, calculate newRow based on those results. - - val (_, initF) = CompileWithAggregators[AsmFunction2RegionLongUnit]( - ctx, - extracted.states, - FastSeq(( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), - )), - FastSeq(classInfo[Region], LongInfo), - UnitInfo, - Begin(FastSeq(extracted.init)), - ) - - val serializeF = extracted.serialize(ctx, spec) - - val (_, eltSeqF) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( - ctx, - extracted.states, - FastSeq( - ( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), - ), - ( - TableIR.rowName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)), - ), - ), - FastSeq(classInfo[Region], LongInfo, LongInfo), - UnitInfo, - extracted.seqPerElt, - ) - - val read = extracted.deserialize(ctx, spec) - val write = extracted.serialize(ctx, spec) - val combOpFNeedsPool = extracted.combOpFSerializedFromRegionPool(ctx, spec) - - val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - CompileWithAggregators[AsmFunction3RegionLongLongLong]( - ctx, - extracted.states, - FastSeq( - ( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), - ), - ( - TableIR.rowName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)), - ), - ), - FastSeq(classInfo[Region], LongInfo, LongInfo), - LongInfo, - Let( - FastSeq(extracted.resultRef.name -> extracted.results), - Coalesce(FastSeq( - extracted.postAggIR, - Die("Internal error: TableMapRows: row expression missing", extracted.postAggIR.typ), - )), - ), - ) - - // 1. 
init op on all aggs and write out to initPath - val initAgg = ctx.r.pool.scopedRegion { aggRegion => - ctx.r.pool.scopedRegion { fRegion => - val init = initF(ctx.theHailClassLoader, fsBc.value, ctx.taskContext, fRegion) - init.newAggState(aggRegion) - init(fRegion, globals.value.offset) - serializeF(ctx.theHailClassLoader, ctx.taskContext, aggRegion, init.getAggOffset()) - } - } - - if (ctx.getFlag("distributed_scan_comb_op") != null && extracted.shouldTreeAggregate) { - val fsBc = ctx.fs.broadcast - val tmpBase = ctx.createTmpPath("table-map-rows-distributed-scan") - val d = digitsNeeded(rvd.getNumPartitions) - val files = rvd.mapPartitionsWithIndex { (i, ctx, it) => - val path = tmpBase + "/" + partFile(d, i, TaskContext.get) - val globalRegion = ctx.freshRegion() - val globals = if (scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) - else 0 - - ctx.r.pool.scopedSmallRegion { aggRegion => - val tc = SparkTaskContext.get() - val seq = eltSeqF(theHailClassLoaderForSparkWorkers, fsBc.value, tc, globalRegion) - - seq.setAggState( - aggRegion, - read(theHailClassLoaderForSparkWorkers, tc, aggRegion, initAgg), - ) - it.foreach { ptr => - seq(ctx.region, globals, ptr) - ctx.region.clear() - } - using(new DataOutputStream(fsBc.value.create(path))) { os => - val bytes = write(theHailClassLoaderForSparkWorkers, tc, aggRegion, seq.getAggOffset()) - os.writeInt(bytes.length) - os.write(bytes) - } - Iterator.single(path) - } - }.collect() - - val fileStack = new BoxedArrayBuilder[Array[String]]() - var filesToMerge: Array[String] = files - while (filesToMerge.length > 1) { - val nToMerge = filesToMerge.length / 2 - log.info(s"Running distributed combine stage with $nToMerge tasks") - fileStack += filesToMerge - - filesToMerge = - ContextRDD.weaken(SparkBackend.sparkContext("TableMapRows.execute").parallelize( - 0 until nToMerge, - nToMerge, - )) - .cmapPartitions { (ctx, it) => - val i = it.next() - assert(it.isEmpty) - val path = tmpBase + "/" + partFile(d, i, TaskContext.get) - val file1 = filesToMerge(i * 2) - val file2 = filesToMerge(i * 2 + 1) - - def readToBytes(is: DataInputStream): Array[Byte] = { - val len = is.readInt() - val b = new Array[Byte](len) - is.readFully(b) - b - } - - val b1 = using(new DataInputStream(fsBc.value.open(file1)))(readToBytes) - val b2 = using(new DataInputStream(fsBc.value.open(file2)))(readToBytes) - using(new DataOutputStream(fsBc.value.create(path))) { os => - val bytes = combOpFNeedsPool(() => - (ctx.r.pool, theHailClassLoaderForSparkWorkers, SparkTaskContext.get()) - )(b1, b2) - os.writeInt(bytes.length) - os.write(bytes) - } - Iterator.single(path) - }.collect() - } - fileStack += filesToMerge - - val itF = { (i: Int, ctx: RVDContext, it: Iterator[Long]) => - val globalRegion = ctx.freshRegion() - val globals = if (rowIterationNeedsGlobals || scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) - else - 0 - val partitionAggs = { - var j = 0 - var x = i - val ab = new BoxedArrayBuilder[String] - while (j < fileStack.length) { - assert(x <= fileStack(j).length) - if (x % 2 != 0) { - x -= 1 - ab += fileStack(j)(x) - } - assert(x % 2 == 0) - x = x / 2 - j += 1 - } - assert(x == 0) - var b = initAgg - ab.result().reverseIterator.foreach { path => - def readToBytes(is: DataInputStream): Array[Byte] = { - val len = is.readInt() - val b = new Array[Byte](len) - is.readFully(b) - b - } - - b = combOpFNeedsPool(() => - (ctx.r.pool, 
theHailClassLoaderForSparkWorkers, SparkTaskContext.get()) - )(b, using(new DataInputStream(fsBc.value.open(path)))(readToBytes)) - } - b - } - - val aggRegion = ctx.freshRegion() - val hcl = theHailClassLoaderForSparkWorkers - val tc = SparkTaskContext.get() - val newRow = f(hcl, fsBc.value, tc, globalRegion) - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) - var aggOff = read(hcl, tc, aggRegion, partitionAggs) - - val res = it.map { ptr => - newRow.setAggState(aggRegion, aggOff) - val newPtr = newRow(ctx.region, globals, ptr) - aggOff = newRow.getAggOffset() - seq.setAggState(aggRegion, aggOff) - seq(ctx.region, globals, ptr) - aggOff = seq.getAggOffset() - newPtr - } - res - } - copy( - typ = newType, - rvd = rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF), - ) - } - - // 2. load in init op on each partition, seq op over partition, write out. - val scanPartitionAggs = SpillingCollectIterator( - ctx.localTmpdir, - ctx.fs, - rvd.mapPartitionsWithIndex { (i, ctx, it) => - val globalRegion = ctx.partitionRegion - val globals = if (scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) - else 0 - - SparkTaskContext.get().getRegionPool().scopedSmallRegion { aggRegion => - val hcl = theHailClassLoaderForSparkWorkers - val tc = SparkTaskContext.get() - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) - - seq.setAggState(aggRegion, read(hcl, tc, aggRegion, initAgg)) - it.foreach { ptr => - seq(ctx.region, globals, ptr) - ctx.region.clear() - } - Iterator.single(write(hcl, tc, aggRegion, seq.getAggOffset())) - } - }, - ctx.getFlag("max_leader_scans").toInt, - ) - - // 3. load in partition aggregations, comb op as necessary, write back out. - val partAggs = scanPartitionAggs.scanLeft(initAgg)(combOpFNeedsPool(() => - (ctx.r.pool, ctx.theHailClassLoader, ctx.taskContext) - )) - val scanAggCount = rvd.getNumPartitions - val partitionIndices = new Array[Long](scanAggCount) - val scanAggsPerPartitionFile = ctx.createTmpPath("table-map-rows-scan-aggs-part") - using(ctx.fs.createNoCompression(scanAggsPerPartitionFile)) { os => - partAggs.zipWithIndex.foreach { case (x, i) => - if (i < scanAggCount) { - log.info(s"TableMapRows scan: serializing combined agg $i") - partitionIndices(i) = os.getPosition - os.writeInt(x.length) - os.write(x, 0, x.length) - } - } - } - - // 4. load in partStarts, calculate newRow based on those results. 
- val itF = { (i: Int, ctx: RVDContext, filePosition: Long, it: Iterator[Long]) => - val globalRegion = ctx.partitionRegion - val globals = if (rowIterationNeedsGlobals || scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) - else - 0 - val partitionAggs = using(fsBc.value.openNoCompression(scanAggsPerPartitionFile)) { is => - is.seek(filePosition) - val aggSize = is.readInt() - val partAggs = new Array[Byte](aggSize) - var nread = is.read(partAggs, 0, aggSize) - var r = nread - while (r > 0 && nread < aggSize) { - r = is.read(partAggs, nread, aggSize - nread) - if (r > 0) nread += r - } - if (nread != aggSize) { - fatal(s"aggs read wrong number of bytes: $nread vs $aggSize") - } - partAggs - } - - val aggRegion = ctx.freshRegion() - val hcl = theHailClassLoaderForSparkWorkers - val tc = SparkTaskContext.get() - val newRow = f(hcl, fsBc.value, tc, globalRegion) - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) - var aggOff = read(hcl, tc, aggRegion, partitionAggs) - - var idx = 0 - it.map { ptr => - newRow.setAggState(aggRegion, aggOff) - val off = newRow(ctx.region, globals, ptr) - seq.setAggState(aggRegion, newRow.getAggOffset()) - idx += 1 - seq(ctx.region, globals, ptr) - aggOff = seq.getAggOffset() - off - } - } - - copy( - typ = newType, - rvd = rvd.mapPartitionsWithIndexAndValue( - RVDType(rTyp.asInstanceOf[PStruct], typ.key), - partitionIndices, - )(itF), - ) - } - def orderBy(sortFields: IndexedSeq[SortField]): TableValue = { val newType = typ.copy(key = FastSeq()) val physicalKey = rvd.typ.key diff --git a/hail/hail/src/is/hail/expr/ir/lowering/ExecuteRelational.scala b/hail/hail/src/is/hail/expr/ir/lowering/ExecuteRelational.scala index 9661a87e061..5e70a7b0c2e 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/ExecuteRelational.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/ExecuteRelational.scala @@ -84,9 +84,8 @@ object ExecuteRelational { TableValueIntermediate(TableValue(ctx, typ, globals, rvd)) case TableMapGlobals(child, newGlobals) => TableValueIntermediate(recur(child).asTableValue(ctx).mapGlobals(newGlobals)) - case TableMapRows(child, newRow) => - val extracted = agg.Extract(newRow, r.requirednessAnalysis, isScan = true) - TableValueIntermediate(recur(child).asTableValue(ctx).mapRows(extracted)) + case ir: TableMapRows => + TableStageIntermediate(ctx.backend.tableToTableStage(ctx, ir, r)) case TableMapPartitions(child, globalName, partitionStreamName, body, _, allowedOverlap) => TableValueIntermediate( diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 995d0b668c4..e51f7ccfbc0 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -130,7 +130,6 @@ class Backend(abc.ABC): # Must match knownFlags in HailFeatureFlags.scala _flags_env_vars_and_defaults: ClassVar[Dict[str, Tuple[str, Optional[str]]]] = { "cachedir": ("HAIL_CACHE_DIR", None), - "distributed_scan_comb_op": ("HAIL_DEV_DISTRIBUTED_SCAN_COMB_OP", None), "gcs_requester_pays_buckets": ("HAIL_GCS_REQUESTER_PAYS_BUCKETS", None), "gcs_requester_pays_project": ("HAIL_GCS_REQUESTER_PAYS_PROJECT", None), "grouped_aggregate_buffer_size": ("HAIL_GROUPED_AGGREGATE_BUFFER_SIZE", "50"), diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py index 6507dc42b0c..a9f789a3a11 100644 --- a/hail/python/test/hail/expr/test_expr.py +++ b/hail/python/test/hail/expr/test_expr.py @@ -13,7 +13,7 @@ from hail.expr.functions import _cdf_combine, 
_error_from_cdf, _result_from_raw_cdf from hail.expr.types import tarray, tbool, tcall, tfloat, tfloat32, tfloat64, tint, tint32, tint64, tstr, tstruct -from ..helpers import assert_evals_to, convert_struct_to_dict, qobtest, resource, test_timeout, with_flags +from ..helpers import assert_evals_to, convert_struct_to_dict, qobtest, resource, test_timeout def _test_many_equal(test_cases): @@ -704,7 +704,6 @@ def test_agg_densify(self): ] @qobtest - @with_flags(distributed_scan_comb_op='1') def test_densify_table(self): ht = hl.utils.range_table(100, n_partitions=33) ht = ht.annotate(arr=hl.range(100).map(lambda idx: hl.or_missing(idx == ht.idx, idx)))
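
Reviewer note (not part of the patch): a minimal Python sketch of the kind of scan-bearing query that previously executed through the removed TableValue.mapRows branches and now lowers through ctx.backend.tableToTableStage via the ExecuteRelational change above. The table size, partition count, and field name below are illustrative assumptions, not taken from this change.

    import hail as hl

    hl.init()  # any backend; the default local Spark backend is assumed here

    ht = hl.utils.range_table(100, n_partitions=8)
    # hl.scan.sum introduces a scan aggregator, so this annotate lowers to a
    # TableMapRows node with non-empty scans -- the case the deleted mapRows
    # code paths (including the distributed_scan_comb_op branch) used to
    # handle on the Spark backend.
    ht = ht.annotate(running_total=hl.scan.sum(ht.idx))
    ht._force_count()  # forces execution through the lowered TableStage path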