Skip to content

Commit

Permalink
BestFirstSearch: extract queues into a new class
Browse files Browse the repository at this point in the history
Move `trackState` into `getActiveSplits` as well, since that is what
should constitute a visit.
  • Loading branch information
kitbellew committed Sep 20, 2024
1 parent 4022f2f commit 76afee2
Showing 1 changed file with 46 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,12 @@ private class BestFirstSearch private (range: Set[Range])(implicit
depth: Int = 0,
maxCost: Int = Integer.MAX_VALUE,
): State = {
def newGeneration = new mutable.PriorityQueue[State]()
var Q = newGeneration
var generations: List[mutable.PriorityQueue[State]] = Nil
def addGeneration() = if (Q.nonEmpty) {
generations = Q :: generations
Q = newGeneration
}
Q += start

implicit val Q: StateQueue = new StateQueue(depth)
def enqueue(state: State) = Q.enqueue(state)
enqueue(start)

// TODO(olafur) this while loop is waaaaaaaaaaaaay tooo big.
while (true) {
while (!Q.isEmpty()) {
val curr = Q.dequeue()
if (curr.depth >= tokens.length) return curr

Expand All @@ -115,8 +108,6 @@ private class BestFirstSearch private (range: Set[Range])(implicit
noOptZones.contains(leftTok)

if (noOptZone || shouldEnterState(curr)) {
trackState(curr, depth, Q.length)

if (explored > style.runner.maxStateVisits) {
complete(deepestYet)
throw new Error.SearchStateExploded(
Expand All @@ -131,7 +122,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
optimizer.dequeueOnNewStatements && curr.allAltAreNL &&
!(depth == 0 && noOptZone) &&
(leftTok.is[Token.KwElse] || statementStarts.contains(curr.depth))
) addGeneration()
) Q.addGeneration()

val noBlockClose = start == curr && 0 != maxCost || !noOptZone ||
!optimizer.recurseOnBlocks
Expand Down Expand Up @@ -168,7 +159,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
else shortestPath(nextState, opt.token, depth + 1, maxCost = 0)
val furtherState =
if (null == nextNextState) null
else traverseSameLine(nextNextState, depth)
else traverseSameLine(nextNextState)
if (null == furtherState) if (killOnFail(opt)) null else nextState
else if (
furtherState.appliedPenalty > nextNextState.appliedPenalty
Expand Down Expand Up @@ -202,24 +193,15 @@ private class BestFirstSearch private (range: Set[Range])(implicit
}
}
}

if (Q.isEmpty) {
if (generations.isEmpty) return null

Q = generations.head
generations = generations.tail
}
}

// unreachable
null
}

private def getActiveSplits(
ft: FormatToken,
state: State,
maxCost: Int,
private def getActiveSplits(ft: FormatToken, state: State, maxCost: Int)(
implicit Q: StateQueue,
): Seq[Split] = {
trackState(state)
val useProvided = ft.meta.formatOff || !ft.inside(range)
val active = state.policy.execute(Decision(ft, routes(state.depth)))
.filter(x => x.isActive && x.cost <= maxCost)
Expand All @@ -235,38 +217,38 @@ private class BestFirstSearch private (range: Set[Range])(implicit
splits.sortBy(_.cost)
}

private def trackState(state: State, depth: Int, queueSize: Int)(implicit
style: ScalafmtConfig,
): Unit = {
if (state.depth > deepestYet.depth) deepestYet = state
style.runner.event(FormatEvent.VisitToken(tokens(state.depth)))
visits(state.depth) += 1
private def trackState(state: State)(implicit Q: StateQueue): Unit = {
val idx = state.depth
if (idx > deepestYet.depth) deepestYet = state
initStyle.runner.event(FormatEvent.VisitToken(tokens(idx)))
visits(idx) += 1
explored += 1
style.runner.event(FormatEvent.Explored(explored, depth, queueSize))
initStyle.runner.event(FormatEvent.Explored(explored, Q.nested, Q.length))
}

/** Follow states having single active non-newline split
*/
@tailrec
private def traverseSameLine(state: State, depth: Int): State =
private def traverseSameLine(
state: State,
)(implicit queue: StateQueue): State =
if (state.depth >= tokens.length) state
else {
val splitToken = tokens(state.depth)
implicit val style: ScalafmtConfig = styleMap.at(splitToken)
trackState(state, depth, 0)
getActiveSplits(splitToken, state, Int.MaxValue) match {
case Seq() => null // dead end if empty
case Seq(split) =>
if (split.isNL) state
else {
style.runner.event(FormatEvent.Enqueue(split))
val nextState = state.next(split, nextAllAltAreNL = false)
traverseSameLine(nextState, depth)
traverseSameLine(nextState)
}
case ss
if state.appliedPenalty == 0 &&
RightParenOrBracket(splitToken.right) =>
traverseSameLineZeroCost(ss.filter(_.cost == 0), state, depth)
traverseSameLineZeroCost(ss.filter(_.cost == 0), state)
case _ => state
}
}
Expand All @@ -275,8 +257,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
private def traverseSameLineZeroCost(
splits: Seq[Split],
state: State,
depth: Int,
)(implicit style: ScalafmtConfig): State = splits match {
)(implicit style: ScalafmtConfig, queue: StateQueue): State = splits match {
case Seq(split) if !split.isNL =>
style.runner.event(FormatEvent.Enqueue(split))
val nextState = state.next(split, nextAllAltAreNL = false)
Expand All @@ -286,9 +267,8 @@ private class BestFirstSearch private (range: Set[Range])(implicit
val nextToken = tokens(nextState.depth)
if (RightParenOrBracket(nextToken.right)) {
implicit val style: ScalafmtConfig = styleMap.at(nextToken)
trackState(nextState, depth, 0)
val nextSplits = getActiveSplits(nextToken, nextState, maxCost = 0)
traverseSameLineZeroCost(nextSplits, nextState, depth)
traverseSameLineZeroCost(nextSplits, nextState)
} else nextState
}
case _ => state
Expand Down Expand Up @@ -390,4 +370,29 @@ object BestFirstSearch {
private def useNoOptZones(implicit style: ScalafmtConfig): Boolean =
style.runner.optimizer.disableOptimizationsInsideSensitiveAreas

class StateQueue(val nested: Int)(implicit stateOrdering: Ordering[State]) {
private def newGeneration = new mutable.PriorityQueue[State]()
var generation: mutable.PriorityQueue[State] = newGeneration
var generations: List[mutable.PriorityQueue[State]] = Nil

def addGeneration(): Unit = if (generation.nonEmpty) {
generations = generation :: generations
generation = newGeneration
}

def dequeue(): State = generation.dequeue()
def enqueue(state: State): Unit = generation.enqueue(state)
def length: Int = generation.length
@tailrec
final def isEmpty(): Boolean = generation.isEmpty && {
generations match {
case head :: tail =>
generation = head
generations = tail
isEmpty()
case _ => true
}
}
}

}

0 comments on commit 76afee2

Please sign in to comment.