From 32629eb8f13c1f72b28cca8bac714c937dcd3fdf Mon Sep 17 00:00:00 2001
From: Jason Dai
Date: Fri, 31 May 2013 15:45:26 +0800
Subject: [PATCH 1/2] Consolidations of shuffle files from different map tasks

---
 .../spark/BlockStoreShuffleFetcher.scala      |  11 +-
 .../main/scala/spark/MapOutputTracker.scala   |   6 +-
 .../main/scala/spark/PairRDDFunctions.scala   |   2 +
 .../scala/spark/scheduler/MapStatus.scala     |   6 +-
 .../spark/scheduler/ShuffleMapTask.scala      |  25 ++-
 .../scala/spark/storage/BlockManager.scala    |   3 +
 .../main/scala/spark/storage/DiskStore.scala  |  28 ++-
 .../spark/storage/ShuffleBlockManager.scala   | 167 ++++++++++++++++--
 8 files changed, 206 insertions(+), 42 deletions(-)

diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index e1fb02157a..753eed8b4f 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -22,16 +22,19 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))

-    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
-    for (((address, size), index) <- statuses.zipWithIndex) {
-      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
+    val splitsByAddress = new HashMap[BlockManagerId, HashMap[Int, Long]]
+    for ((address, groupId, size) <- statuses) {
+      val groupedSplits = splitsByAddress.getOrElseUpdate(address, new HashMap[Int, Long])
+      val currSize = groupedSplits.getOrElse(groupId, 0L)
+      if (size > currSize) groupedSplits.put(groupId, size)
     }

     val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
       case (address, splits) =>
-        (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+        (address, splits.toSeq.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
     }
+    logDebug("Fetched grouped splits: " + blocksByAddress)

     def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
       val blockId = blockPair._1
       val blockOption = blockPair._2
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index fde597ffd1..d3e75fb3d0 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -117,7 +117,7 @@ private[spark] class MapOutputTracker extends Logging {
   private val fetching = new HashSet[Int]

   // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
-  def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+  def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Int, Long)] = {
     val statuses = mapStatuses.get(shuffleId).orNull
     if (statuses == null) {
       logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
@@ -280,7 +280,7 @@ private[spark] object MapOutputTracker {
   private def convertMapStatuses(
       shuffleId: Int,
       reduceId: Int,
-      statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+      statuses: Array[MapStatus]): Array[(BlockManagerId, Int, Long)] = {
     assert (statuses != null)
     statuses.map {
       status =>
@@ -288,7 +288,7 @@ private[spark] object MapOutputTracker {
           throw new FetchFailedException(null, shuffleId, -1, reduceId,
             new Exception("Missing an output location for shuffle " + shuffleId))
         } else {
-          (status.location, decompressSize(status.compressedSizes(reduceId)))
+          (status.location, status.groupId, decompressSize(status.compressedSizes(reduceId)))
         }
       }
   }
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 2b0e697337..b6a016ec50 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -69,12 +69,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
       val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
+      logInfo("serializerClass=" + serializerClass)
       partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
     } else {
       // Don't apply map-side combiner.
       // A sanity check to make sure mergeCombiners is not defined.
       assert(mergeCombiners == null)
       val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
+      logInfo("serializerClass=" + serializerClass)
       values.mapPartitions(aggregator.combineValuesByKey(_), true)
     }
   }
diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
index 203abb917b..baabae8e71 100644
--- a/core/src/main/scala/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -8,19 +8,21 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable}
  * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
  * The map output sizes are compressed using MapOutputTracker.compressSize.
  */
-private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
+private[spark] class MapStatus(var location: BlockManagerId, var groupId: Int, var compressedSizes: Array[Byte])
   extends Externalizable {

-  def this() = this(null, null)  // For deserialization only
+  def this() = this(null, 0, null)  // For deserialization only

   def writeExternal(out: ObjectOutput) {
     location.writeExternal(out)
+    out.writeInt(groupId)
     out.writeInt(compressedSizes.length)
     out.write(compressedSizes)
   }

   def readExternal(in: ObjectInput) {
     location = BlockManagerId(in)
+    groupId = in.readInt()
     compressedSizes = new Array[Byte](in.readInt())
     in.readFully(compressedSizes)
   }
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 95647389c3..2b9f633c99 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -132,29 +132,26 @@ private[spark] class ShuffleMapTask(
     val blockManager = SparkEnv.get.blockManager
     var shuffle: ShuffleBlocks = null
-    var buckets: ShuffleWriterGroup = null
+    var group: ShuffleWriterGroup = null

     try {
       // Obtain all the block writers for shuffle blocks.
       val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
       shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
-      buckets = shuffle.acquireWriters(partition)
+      group = shuffle.acquireWriters(partition)

       // Write the map output to its associated buckets.
       for (elem <- rdd.iterator(split, taskContext)) {
         val pair = elem.asInstanceOf[(Any, Any)]
         val bucketId = dep.partitioner.getPartition(pair._1)
-        buckets.writers(bucketId).write(pair)
+        group.writers(bucketId).write(pair)
       }

       // Commit the writes. Get the size of each bucket block (total block size).
       var totalBytes = 0L
-      val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
-        writer.commit()
-        writer.close()
-        val size = writer.size()
-        totalBytes += size
-        MapOutputTracker.compressSize(size)
+      val compressedSizes: Array[Byte] = group.writers.map { writer: BlockObjectWriter =>
+        totalBytes += writer.commit()
+        MapOutputTracker.compressSize(writer.size())
       }

       // Update shuffle metrics.
@@ -162,18 +159,18 @@ private[spark] class ShuffleMapTask(
       shuffleMetrics.shuffleBytesWritten = totalBytes
       metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)

-      return new MapStatus(blockManager.blockManagerId, compressedSizes)
+      return new MapStatus(blockManager.blockManagerId, group.id, compressedSizes)
     } catch { case e: Exception =>
       // If there is an exception from running the task, revert the partial writes
       // and throw the exception upstream to Spark.
-      if (buckets != null) {
-        buckets.writers.foreach(_.revertPartialWrites())
+      if (group != null) {
+        group.writers.foreach(_.revertPartialWrites())
       }
       throw e
     } finally {
       // Release the writers back to the shuffle block manager.
-      if (shuffle != null && buckets != null) {
-        shuffle.releaseWriters(buckets)
+      if (shuffle != null && group != null) {
+        shuffle.releaseWriters(group)
       }
       // Execute the callbacks on task completion.
       taskContext.executeOnCompleteCallbacks()
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index d35c43f194..33ff20bdd4 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -287,6 +287,7 @@ private[spark] class BlockManager(
    * never deletes (recent) items.
    */
   def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+    shuffleBlockManager.closeBlock(blockId)
     diskStore.getValues(blockId, serializer).orElse(
       sys.error("Block " + blockId + " not found on disk, though it should be"))
   }
@@ -382,6 +383,8 @@ private[spark] class BlockManager(
     // As an optimization for map output fetches, if the block is for a shuffle, return it
     // without acquiring a lock; the disk store never deletes (recent) items so this should work
     if (ShuffleBlockManager.isShuffle(blockId)) {
+      //close the shuffle Writers for blockId
+      shuffleBlockManager.closeBlock(blockId)
       return diskStore.getBytes(blockId) match {
         case Some(bytes) =>
           Some(bytes)
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index c7281200e7..57eeda3e03 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -26,21 +26,27 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
   extends BlockStore(blockManager) with Logging {

   class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
-    extends BlockObjectWriter(blockId) {
+    extends BlockObjectWriter(blockId) with Logging {

-    private val f: File = createFile(blockId /*, allowAppendExisting */)
+    private var f: File = createFile(blockId)

     // The file channel, used for repositioning / truncating the file.
     private var channel: FileChannel = null
     private var bs: OutputStream = null
     private var objOut: SerializationStream = null
     private var lastValidPosition = 0L
+    private var initialPosition = 0L

     override def open(): DiskBlockObjectWriter = {
       val fos = new FileOutputStream(f, true)
       channel = fos.getChannel()
       bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
       objOut = serializer.newInstance().serializeStream(bs)
+
+      //commit possible file header
+      commit()
+      initialPosition = lastValidPosition
+
       this
     }
@@ -59,7 +65,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
     // Flush the partial writes, and set valid length to be the length of the entire file.
     // Return the number of bytes written for this commit.
     override def commit(): Long = {
-      // NOTE: Flush the serializer first and then the compressed/buffered output stream
       objOut.flush()
       bs.flush()
       val prevPos = lastValidPosition
@@ -68,11 +73,28 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
     }

     override def revertPartialWrites() {
+      // Revert by discarding current writes, except that if no values have been committed,
+      // we revert by recreating the file (otherwise there are errors when reading objects from the file later on)
+      if (lastValidPosition == initialPosition)
+        recerateFile()
+      else
+        discardWrites()
+    }
+
+    private def recerateFile () {
+      close ()
+      f.delete()
+      f = createFile(blockId)
+      open()
+    }
+
+    private def discardWrites () {
       // Discard current writes. We do this by flushing the outstanding writes and
       // truncate the file to the last valid position.
       objOut.flush()
       bs.flush()
       channel.truncate(lastValidPosition)
+      channel.position(lastValidPosition)
     }

     override def write(value: Any) {
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
index 49eabfb0d2..7114d15bed 100644
--- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -2,10 +2,17 @@ package spark.storage

 import spark.serializer.Serializer

+import java.util.concurrent.{ConcurrentLinkedQueue,ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger

-private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+import spark.util.MetadataCleaner
+
+import scala.collection.JavaConversions
+import scala.collection.mutable.ArrayBuffer
+import spark.Logging
+import spark.SparkException
+private[spark] class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])

 private[spark]
 trait ShuffleBlocks {
@@ -15,30 +22,158 @@ trait ShuffleBlocks {


 private[spark]
-class ShuffleBlockManager(blockManager: BlockManager) {
+class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
+  initLogging()
+
+  val metadataCleaner = new MetadataCleaner("ShuffleBlockManager", this.cleanup)

   def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
     new ShuffleBlocks {
+      val pool = getPool(shuffleId)
       // Get a group of writers for a map task.
-      override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
-        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
-          val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
-          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
-        }
-        new ShuffleWriterGroup(mapId, writers)
-      }
+      override def acquireWriters(mapId: Int) = pool.acquireGroup(numBuckets, serializer)
+      override def releaseWriters(group: ShuffleWriterGroup) = pool.returnGroup(group)
+    }

-      override def releaseWriters(group: ShuffleWriterGroup) = {
-        // Nothing really to release here.
+  }
+
+  /**
+   * Get a Shuffle Pool
+   */
+  private def getPool(shuffleId: Int) : ShuffleBlocksPool = {
+    pools.putIfAbsent(shuffleId,new ShuffleBlocksPool(shuffleId))
+    pools.get(shuffleId)
+  }
+
+  /**
+   * Close one block.
+   * This method is not synchronized, as we assume each shuffle block is only
+   * retrieved by one thread
+   */
+  def closeBlock(blockId: String) {
+    val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
+    blockId match {
+      case regex(shuffleId, groupId, bucketId) =>
+        logDebug("closeBlock shuffleId: " + shuffleId + ", groupId: " + groupId + ", bucketId: " + bucketId)
+        val pool = getPool(shuffleId.toInt)
+        if (pool != null)
+          pool.closeBlock(groupId.toInt,bucketId.toInt)
+      case _ =>
+        throw new SparkException(
+          "Failed to get block " + blockId + ", which is not a shuffle block")
     }
-  }
 }
+
+  /**
+   * Clean up the closed, unused pools.
+   * TODO, 1. need to handle the case when a shufflePool is removed after this pool is acquired
+   * for writing (e.g. in some failure recovery process). 2. Better consider ttl
+   */
+  def cleanup(cleanupTime: Long){
+    JavaConversions.asScalaConcurrentMap(pools).retain( (shuffleId,pool) => pool.allGroupsClosed() )
+  }
+
+  class ShuffleBlocksPool (val shuffleId: Int) {
+    // Keep track of all groups that have been generated. We can use this array buffer to
+    // get the complete list of groups (and writers) for a particular shuffle. Default the
+    // size to 32.
+    val allGroups = new ArrayBuffer[ShuffleWriterGroup](32)
+
+    // Keep track of groups that are not currently in use (i.e. no threads are using them)
+    val freeGroups = new ConcurrentLinkedQueue[ShuffleWriterGroup]
+
+    // Used to generate the next group id.
+    val nextGroupID = new AtomicInteger(0)
+
+    // Check if this pool is closed and not available for further use.
+    val isClosed = false
+
+    /**
+     * Acquire a new group from pool
+     */
+    def acquireGroup(numBuckets: Int, serializer: Serializer) : ShuffleWriterGroup = {
+      //TODO. throws exception now. This needs to be handled. maybe reopen it.
+      if (isClosed)
+        throw new SparkException("ShuffleBlocksPool "+ shuffleId +" is closed and can not be used to acquired new blocks")
+      var group = freeGroups.poll()
+      if (group == null) {
+        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+        val groupId = nextGroupID.getAndIncrement()
+        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, groupId)
+          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+        }
+        group = new ShuffleWriterGroup(groupId, writers)
+        openGroup(group)
+        allGroups += group
+      }
+      group
+    }
+
+    /**
+     * Return a group to pool
+     */
+    def returnGroup(group: ShuffleWriterGroup) {
+      freeGroups.add(group)
+    }
+
+    /**
+     * Close one block writer. We don't check whether the group is in the free list,
+     * as this is called right before the shuffle block is retrieved
+     */
+    def closeBlock(groupId:Int, bucketId:Int) {
+      if (allGroups(groupId.toInt).writers(bucketId.toInt).isOpen)
+        allGroups(groupId.toInt).writers(bucketId.toInt).close()
+    }
+
+    /**
+     * Close all writers in a group. Reserved for possible optimization
+     */
+    def closeGroup(groupId:Int){
+      allGroups(groupId.toInt).writers.map { _.close()}
+    }
+
+    /**
+     * Check if a group is closed. Reserved for finer-grained metadata cleanup
+     */
+    def groupClosed(group: ShuffleWriterGroup): Boolean = {
+      group.writers.forall(!_.isOpen)
+    }
+
+    /**
+     * Check if all the groups in this shuffle are released (no writing is going on)
+     */
+    def allGroupsReleased() : Boolean = {
+      return (freeGroups.size == nextGroupID.get - 1)
+    }
+
+    /**
+     * Check if all the groups are released and closed, i.e. ready to be cleaned up.
+     */
+    def allGroupsClosed(): Boolean = {
+      //check if all groups are released
+      if (allGroupsReleased())
+        allGroups.forall(groupClosed(_))
+      else
+        false
+    }
+
+    /**
+     * Open all writers in a group
+     */
+    private def openGroup(group: ShuffleWriterGroup) : ShuffleWriterGroup = {
+      //open all the writers
+      group.writers.map { writer => { if(!writer.isOpen) writer.open()} }
+      group
+    }
   }
+  //keep track of pools for all shuffles indexed by Id
+  val pools = new ConcurrentHashMap[Int,ShuffleBlocksPool]

-private[spark]
-object ShuffleBlockManager {
+}
+
+private[spark] object ShuffleBlockManager {
   // Returns the block id for a given shuffle block.
   def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {

From d53d332b7ebc37c0eb5c99e8835d826170f4c876 Mon Sep 17 00:00:00 2001
From: Jason Dai
Date: Tue, 4 Jun 2013 14:26:31 +0800
Subject: [PATCH 2/2] update shuffle consolidation code per review comments

---
 .../main/scala/spark/MapOutputTracker.scala   |  2 +
 .../main/scala/spark/PairRDDFunctions.scala   |  4 +-
 .../main/scala/spark/storage/DiskStore.scala  |  4 +-
 .../spark/storage/ShuffleBlockManager.scala   | 84 +++++++++++--------
 4 files changed, 56 insertions(+), 38 deletions(-)

diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index d3e75fb3d0..5e69e927dd 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -117,6 +117,8 @@ private[spark] class MapOutputTracker extends Logging {
   private val fetching = new HashSet[Int]

   // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
+  // Return an array of map output locations for the given reduceId, one for each ShuffleMapTask, in the form of
+  // (BlockManagerId, groupId of the shuffle file, size of the shuffle file when the task writes its output)
   def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Int, Long)] = {
     val statuses = mapStatuses.get(shuffleId).orNull
     if (statuses == null) {
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index b6a016ec50..fb8699b9cd 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -69,14 +69,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
       val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
-      logInfo("serializerClass=" + serializerClass)
+ logDebug("serializerClass=" + serializerClass) partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) val values = new ShuffledRDD[K, V](self, partitioner, serializerClass) - logInfo("serializerClass=" + serializerClass) + logDebug("serializerClass=" + serializerClass) values.mapPartitions(aggregator.combineValuesByKey(_), true) } } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 57eeda3e03..59cb47865d 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -76,12 +76,12 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Revert by discarding current writes, except that if no values have been committed, // we revert by recreate the file (otherwise there are errors when reading objects from the file later on if (lastValidPosition == initialPosition) - recerateFile() + recreateFile() else discardWrites() } - private def recerateFile () { + private def recreateFile () { close () f.delete() f = createFile(blockId) diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 7114d15bed..132c4994ad 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -11,8 +11,14 @@ import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer import spark.Logging import spark.SparkException +import scala.collection.JavaConverters._ -private[spark] class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) +private[spark] class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) { + def open() { + //open all the writers + writers.map { writer => if (!writer.isOpen) writer.open() } + } +} private[spark] trait ShuffleBlocks { @@ -20,7 +26,21 @@ trait ShuffleBlocks { def releaseWriters(group: ShuffleWriterGroup) } - +/** + * On each slave, ShuffleBlockManager maintains a shuffle block pool for each shuffle. + * + * Each pool maintains a list of shuffle block group; a ShuffleMapTask acquires a free group + * when it needs to write its results, and returns the group when it's done. + * + * Each group maintains a list of block writers, each for a different bucket (reduce partition). 
+ *
+ * Each block writer is closed when the BlockManager receives a shuffle request for that block
+ * (i.e., all map tasks are done) and will not be re-opened again.
+ *
+ * If we need to re-run a map task afterwards, Spark will actually re-run all the map tasks on
+ * the same slave; these tasks will then acquire new groups, which effectively discards the
+ * previous shuffle outputs for all these map tasks.
+ */
 private[spark]
 class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   initLogging()
@@ -41,8 +61,13 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
    * Get a Shuffle Pool
    */
   private def getPool(shuffleId: Int) : ShuffleBlocksPool = {
-    pools.putIfAbsent(shuffleId,new ShuffleBlocksPool(shuffleId))
-    pools.get(shuffleId)
+    val pool = pools.get(shuffleId)
+    if (pool == null) {
+      pools.putIfAbsent(shuffleId, new ShuffleBlocksPool(shuffleId))
+      pools.get(shuffleId)
+    }
+    else
+      pool
   }

   /**
@@ -57,7 +82,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
         logDebug("closeBlock shuffleId: " + shuffleId + ", groupId: " + groupId + ", bucketId: " + bucketId)
         val pool = getPool(shuffleId.toInt)
         if (pool != null)
-          pool.closeBlock(groupId.toInt,bucketId.toInt)
+          pool.closeBlock(groupId.toInt, bucketId.toInt)
       case _ =>
         throw new SparkException(
           "Failed to get block " + blockId + ", which is not a shuffle block")
@@ -70,7 +95,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
    * for writing (e.g. in some failure recovery process). 2. Better consider ttl
    */
   def cleanup(cleanupTime: Long){
-    JavaConversions.asScalaConcurrentMap(pools).retain( (shuffleId,pool) => pool.allGroupsClosed() )
+    pools.asScala.retain( (shuffleId,pool) => pool.allGroupsClosed() )
   }
   class ShuffleBlocksPool (val shuffleId: Int) {
     // Keep track of all groups that have been generated. We can use this array buffer to
@@ -92,22 +117,22 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
      * Acquire a new group from pool
      */
     def acquireGroup(numBuckets: Int, serializer: Serializer) : ShuffleWriterGroup = {
-      //TODO. throws exception now. This needs to be handled. maybe reopen it.
-      if (isClosed)
-        throw new SparkException("ShuffleBlocksPool "+ shuffleId +" is closed and can not be used to acquired new blocks")
-      var group = freeGroups.poll()
-      if (group == null) {
-        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        val groupId = nextGroupID.getAndIncrement()
-        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
-          val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, groupId)
-          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
-        }
-        group = new ShuffleWriterGroup(groupId, writers)
-        openGroup(group)
-        allGroups += group
+      //TODO. throws exception now. This needs to be handled. maybe reopen it.
+      if (isClosed)
+        throw new SparkException("ShuffleBlocksPool "+ shuffleId +" is closed and can not be used to acquired new blocks")
+      var group = freeGroups.poll()
+      if (group == null) {
+        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+        val groupId = nextGroupID.getAndIncrement()
+        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, groupId)
+          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
         }
-        group
+        group = new ShuffleWriterGroup(groupId, writers)
+        group.open()
+        allGroups += group
+      }
+      group
     }

     /**
@@ -129,8 +154,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
     /**
      * Close all writers in a group. Reserved for possible optimization
      */
-    def closeGroup(groupId:Int){
-      allGroups(groupId.toInt).writers.map { _.close()}
+    def closeGroup(groupId:Int) {
+      allGroups(groupId.toInt).writers.map { _.close() }
     }

     /**
@@ -144,7 +169,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
      * Check if all the groups in this shuffle are released (no writing is going on)
      */
     def allGroupsReleased() : Boolean = {
-      return (freeGroups.size == nextGroupID.get - 1)
+      freeGroups.size == nextGroupID.get - 1
     }

     /**
@@ -157,19 +182,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
       else
         false
     }
-
-    /**
-     * Open all writers in a group
-     */
-    private def openGroup(group: ShuffleWriterGroup) : ShuffleWriterGroup = {
-      //open all the writers
-      group.writers.map { writer => { if(!writer.isOpen) writer.open()} }
-      group
-    }
   }

   //keep track of pools for all shuffles indexed by Id
-  val pools = new ConcurrentHashMap[Int,ShuffleBlocksPool]
+  val pools = new ConcurrentHashMap[Int, ShuffleBlocksPool]
 }
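The following is a minimal, standalone Scala sketch (not part of either patch) of the reduce-side grouping that patch 1 introduces in BlockStoreShuffleFetcher and that the comment added to MapOutputTracker.getServerStatuses in patch 2 describes. Statuses arrive as one (location, groupId, size) triple per map task; entries that share a consolidated shuffle file are collapsed by keeping the largest reported size per (address, groupId), and each surviving pair is fetched as a single shuffle_<shuffleId>_<groupId>_<reduceId> block. The Address case class and the groupBlocks helper are hypothetical stand-ins for BlockManagerId and the actual fetcher code.

import scala.collection.mutable.HashMap

object GroupedShuffleFetchSketch {
  // Hypothetical stand-in for spark.storage.BlockManagerId.
  case class Address(host: String, port: Int)

  // statuses: one entry per map task, as (address, groupId, size of the consolidated
  // shuffle file observed when that task finished writing).
  def groupBlocks(shuffleId: Int, reduceId: Int,
                  statuses: Seq[(Address, Int, Long)]): Seq[(Address, Seq[(String, Long)])] = {
    val splitsByAddress = new HashMap[Address, HashMap[Int, Long]]
    for ((address, groupId, size) <- statuses) {
      val groupedSplits = splitsByAddress.getOrElseUpdate(address, new HashMap[Int, Long])
      // Several map tasks append to the same group file; keep the largest reported size,
      // which is the file size after the last of them committed its output.
      if (size > groupedSplits.getOrElse(groupId, 0L)) groupedSplits.put(groupId, size)
    }
    // One block to fetch per (address, groupId): shuffle_<shuffleId>_<groupId>_<reduceId>.
    splitsByAddress.toSeq.map { case (address, splits) =>
      (address, splits.toSeq.map { case (groupId, size) =>
        ("shuffle_%d_%d_%d".format(shuffleId, groupId, reduceId), size)
      })
    }
  }

  def main(args: Array[String]): Unit = {
    val a = Address("host-a", 7077)
    // Three map tasks ran on host-a and wrote into two shuffle file groups (0 and 1).
    val statuses = Seq((a, 0, 100L), (a, 0, 250L), (a, 1, 80L))
    // Expect one entry for host-a with blocks shuffle_5_0_2 (size 250) and shuffle_5_1_2 (size 80).
    println(groupBlocks(5, 2, statuses))
  }
}

With this grouping, the number of blocks a reducer fetches from a host is bounded by the number of writer groups on that host (roughly the number of concurrently running map tasks) rather than by the total number of map tasks.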