[SPARK-16719][ML] Random Forests should communicate fewer trees on each iteration

## What changes were proposed in this pull request?

RandomForest currently sends the entire forest to each worker on each iteration. This is because (a) the node queue is FIFO and (b) the closure references the entire array of trees (topNodes). Because of (a), each group of splits spans many trees, especially early in learning; because of (b), every tree is serialized with the closure and sent to the workers.
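
As background (illustration only, not part of the patch), the sketch below shows the closure-capture half of the problem. The `SketchNode` type and `touchAllTrees` method are hypothetical names used only here; the capture behavior is standard Spark closure serialization.

```scala
import org.apache.spark.rdd.RDD

// Minimal sketch (simplified types, not the actual RandomForest code).
// Anything referenced inside an RDD closure is serialized and shipped with the
// task, so a per-point closure that reads the full `topNodes` array sends every
// tree to the workers, even when only one tree has nodes being split.
case class SketchNode(treeIndex: Int, id: Int)

def touchAllTrees(points: RDD[Array[Double]], topNodes: Array[SketchNode]): Long =
  points.map { binnedFeatures =>
    // `topNodes` is captured here in its entirety.
    topNodes.length + binnedFeatures.length
  }.count()
```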

This PR:
(a) Changes the RF node queue to a FILO stack, so that RFs tend to focus on one or a few trees before moving on to others.
(b) Changes topNodes so that only the trees required on that iteration are passed to the workers (see the sketch below).
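
For illustration only, here is a minimal script-style sketch of both changes; the `Node` type, the single-node group, and the names below are simplified stand-ins for the PR's `LearningNode`, `nodesForGroup`, and `topNodesForGroup`.

```scala
import scala.collection.mutable

// Hypothetical, simplified node type standing in for LearningNode.
case class Node(treeIndex: Int, id: Int)

val numTrees = 3
val topNodes = Array.tabulate(numTrees)(t => Node(t, 1))

// (a) LIFO stack: children just pushed are popped first, so successive groups
// tend to stay within one tree rather than cycling across all of them.
val nodeStack = new mutable.Stack[(Int, Node)]
topNodes.zipWithIndex.foreach { case (root, t) => nodeStack.push((t, root)) }

val (treeIdx, node) = nodeStack.pop()
nodeStack.push((treeIdx, Node(treeIdx, node.id * 2)))      // left child
nodeStack.push((treeIdx, Node(treeIdx, node.id * 2 + 1)))  // right child
// The next pops return nodes from the same tree, not the other roots.

// (b) Ship only the trees that actually have nodes in the current group.
val nodesForGroup: Map[Int, Array[Node]] = Map(treeIdx -> Array(node))
val topNodesForGroup: Map[Int, Node] =
  nodesForGroup.keys.map(t => t -> topNodes(t)).toMap
```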

## How was this patch tested?

Unit tests:
* Existing tests for correctness of tree learning
* Manually modifying code and running tests to verify that a small number of trees are communicated on each iteration
  * This last item is hard to test via unit tests given the current APIs.

Author: Joseph K. Bradley <[email protected]>

Closes apache#14359 from jkbradley/rfs-fewer-trees.
jkbradley committed Sep 23, 2016
1 parent a4aeb76 commit 947b8c6
Showing 2 changed files with 46 additions and 34 deletions.
@@ -51,7 +51,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
* findSplits() method during initialization, after which each continuous feature becomes
* an ordered discretized feature with at most maxBins possible values.
*
* The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes
* The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes
* lie at the periphery of the tree being trained. If multiple trees are being trained at once,
* then this queue contains nodes from all of them. Each iteration works roughly as follows:
* On the master node:
@@ -161,31 +161,42 @@ private[spark] object RandomForest extends Logging {
None
}

// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
/*
Stack of nodes to train: (treeIndex, node)
The reason this is a stack is that we train many trees at once, but we want to focus on
completing trees, rather than training all simultaneously. If we are splitting nodes from
1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
training the same tree in the next iteration. This focus allows us to send fewer trees to
workers on each iteration; see topNodesForGroup below.
*/
val nodeStack = new mutable.Stack[(Int, LearningNode)]

val rng = new Random()
rng.setSeed(seed)

// Allocate and queue root nodes.
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex))))

timer.stop("init")

while (nodeQueue.nonEmpty) {
while (nodeStack.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
// Sanity check (should never occur):
assert(nodesForGroup.nonEmpty,
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

// Only send trees to worker if they contain nodes being split this iteration.
val topNodesForGroup: Map[Int, LearningNode] =
nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap

// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
timer.stop("findBestSplits")
}

@@ -334,13 +345,14 @@ private[spark] object RandomForest extends Logging {
*
* @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param topNodesForGroup For each tree in group, tree index -> root node.
* Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
* where nodeIndexInfo stores the index in the group and the
* feature subsets (if using feature subsets).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* @param nodeStack Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
@@ -351,11 +363,11 @@ private[spark] object RandomForest extends Logging {
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
topNodes: Array[LearningNode],
topNodesForGroup: Map[Int, LearningNode],
nodesForGroup: Map[Int, Array[LearningNode]],
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
splits: Array[Array[Split]],
nodeQueue: mutable.Queue[(Int, LearningNode)],
nodeStack: mutable.Stack[(Int, LearningNode)],
timer: TimeTracker = new TimeTracker,
nodeIdCache: Option[NodeIdCache] = None): Unit = {

@@ -437,7 +449,8 @@ private[spark] object RandomForest extends Logging {
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
val nodeIndex =
topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg
@@ -593,10 +606,10 @@ private[spark] object RandomForest extends Logging {

// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftChild.get))
nodeStack.push((treeIndex, node.leftChild.get))
}
if (!rightChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.rightChild.get))
nodeStack.push((treeIndex, node.rightChild.get))
}

logDebug("leftChildIndex = " + node.leftChild.get.id +
@@ -1029,7 +1042,7 @@ private[spark] object RandomForest extends Logging {
* will be needed; this allows an adaptive number of nodes since different nodes may require
* different amounts of memory (if featureSubsetStrategy is not "all").
*
* @param nodeQueue Queue of nodes to split.
* @param nodeStack Queue of nodes to split.
* @param maxMemoryUsage Bound on size of aggregate statistics.
* @return (nodesForGroup, treeToNodeToIndexInfo).
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
@@ -1041,7 +1054,7 @@ private[spark] object RandomForest extends Logging {
* The feature indices are None if not subsampling features.
*/
private[tree] def selectNodesToSplit(
nodeQueue: mutable.Queue[(Int, LearningNode)],
nodeStack: mutable.Stack[(Int, LearningNode)],
maxMemoryUsage: Long,
metadata: DecisionTreeMetadata,
rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
@@ -1054,8 +1067,8 @@ private[spark] object RandomForest extends Logging {
var numNodesInGroup = 0
// If maxMemoryInMB is set very small, we want to still try to split 1 node,
// so we allow one iteration if memUsage == 0.
while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
val (treeIndex, node) = nodeQueue.head
while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
val (treeIndex, node) = nodeStack.top
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
Some(SamplingUtils.reservoirSampleAndCount(Range(0,
@@ -1066,7 +1079,7 @@ private[spark] object RandomForest extends Logging {
// Check if enough memory remains to add this node to the group.
val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
nodeQueue.dequeue()
nodeStack.pop()
mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
node
mutableTreeToNodeToIndexInfo
@@ -1109,5 +1122,4 @@ private[spark] object RandomForest extends Logging {
3 * totalBins
}
}

}
@@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.collection.OpenHashMap
@@ -239,12 +240,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val treeToNodeToIndexInfo = Map((0, Map(
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
)))
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
val nodeStack = new mutable.Stack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)

// don't enqueue leaf nodes into node queue
assert(nodeQueue.isEmpty)
assert(nodeStack.isEmpty)

// set impurity and predict for topNode
assert(topNode.stats !== null)
@@ -281,12 +282,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val treeToNodeToIndexInfo = Map((0, Map(
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
)))
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
val nodeStack = new mutable.Stack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)

// don't enqueue a node into node queue if its impurity is 0.0
assert(nodeQueue.isEmpty)
assert(nodeStack.isEmpty)

// set impurity and predict for topNode
assert(topNode.stats !== null)
@@ -393,16 +394,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val failString = s"Failed on test with:" +
s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
val nodeStack = new mutable.Stack[(Int, LearningNode)]
val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
Range(0, numTrees).foreach { treeIndex =>
topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
nodeStack.push((treeIndex, topNodes(treeIndex)))
}
val rng = new scala.util.Random(seed = seed)
val (nodesForGroup: Map[Int, Array[LearningNode]],
treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)

assert(nodesForGroup.size === numTrees, failString)
assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
@@ -546,7 +547,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}

}

private object RandomForestSuite {
