Skip to content

Commit

Permalink
Merge pull request #3403 from armanbilge/bug/callback-leak
Browse files Browse the repository at this point in the history
Fix `CallbackStack` leak, restore specialized `IODeferred`
  • Loading branch information
armanbilge authored Feb 5, 2023
2 parents 3b0cb0b + 6fce2b3 commit f6adf0f
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 15 deletions.
2 changes: 2 additions & 0 deletions core/js/src/main/scala/cats/effect/CallbackStack.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ private final class CallbackStackOps[A](private val callbacks: js.Array[A => Uni

// Drops every registered callback at once by truncating the backing js.Array to length zero.
@inline def clear(): Unit =
callbacks.length = 0 // javascript is crazy!

// No packing is performed by this array-backed implementation: `callbacks` is left untouched
// and `bound` is handed back unchanged as the reported count.
// NOTE(review): presumably returning `bound` lets a caller that subtracts the result from its
// clear counter reset that counter — confirm against the caller.
@inline def pack(bound: Int): Int = bound
}

private object CallbackStack {
Expand Down
69 changes: 69 additions & 0 deletions core/jvm-native/src/main/scala/cats/effect/CallbackStack.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,75 @@ private final class CallbackStack[A](private[this] var callback: A => Unit)
// Always returns the constant handle 0.
// NOTE(review): presumably the handle value is not needed to locate a callback in this
// implementation — confirm against clearCurrent, which is outside this view.
def currentHandle(): CallbackStack.Handle = 0

// Resets this cell's link to null, detaching whatever cells followed it.
// NOTE(review): lazySet is presumably inherited from AtomicReference (the extends clause is
// outside this view) — confirm.
def clear(): Unit = lazySet(null)

/**
* It is intended that `bound` be tracked externally and incremented on each clear(). Whenever
* pack is called, the number of empty cells removed from the stack is produced. It is
* expected that this value should be subtracted from `bound` for the subsequent pack/clear
* calls. It is permissible to pack on every clear() for simplicity, though it may be more
* reasonable to delay pack() calls until bound exceeds some reasonable threshold.
*
* The observation here is that it is cheapest to remove empty cells from the front of the
* list, but very expensive to remove them from the back of the list, and so we can be
* relatively aggressive about the former and conservative about the latter. In a "pack on
* every clear" protocol, the best possible case is if we're always clearing at the very front
* of the list. In this scenario, pack is always O(1). Conversely, the worst possible scenario
* is when we're clearing at the *end* of the list. In this case, we won't actually remove any
* cells until exactly half the list is emptied (thus, the number of empty cells is equal to
* the number of full cells). In this case, the final pack is O(n), while the accumulated
* wasted packs (which will fail to remove any items) will total to O((n / 2)^2). Thus, in the
* worst case, we would need O((n / 2)^2 + n) operations to clear out the waste, where the
* waste would be accumulated by n / 2 total clears, meaning that the marginal cost added to
* clear is O(n/2 + 2), which is to say, O(n).
*
* In order to reduce this to a sub-linear cost, we need to pack less frequently, with higher
* bounds, as the number of outstanding clears increases. Thus, rather than packing on each
* clear, we should pack only when the clear count reaches a power of two (1, 2, 4, 8, etc). For cases where most of the
* churn is at the head of the list, this remains essentially O(1) and clears frequently. For
* cases where the churn is buried deeper in the list, it becomes O(log n) per clear
* (amortized). This still biases the optimizations towards the head of the list, but ensures
* that packing will still inevitably reach all of the garbage cells.
*/
def pack(bound: Int): Int = {
  // This cell is the head of the stack and is never unlinked itself; packing starts at the
  // cell it points to and uses this cell as the initial parent for any unlinking.
  val first = get()
  if (first eq null)
    0 // nothing hangs off the head, so there is nothing to remove
  else
    first.packInternal(bound, 0, this)
}

/**
 * Walks the chain starting at this cell, unlinking cells whose callback has been cleared, and
 * produces the number of cells actually unlinked.
 *
 * @param bound
 *   remaining budget of populated cells to step past; it is also decremented when an empty
 *   cell is unlinked, so it may go negative
 * @param removed
 *   number of cells unlinked so far on this traversal
 * @param parent
 *   the nearest preceding cell that is still linked; its reference is swung past this cell
 *   when this cell is empty
 * @return
 *   `removed` plus the number of cells unlinked from here onwards
 */
@tailrec
private def packInternal(bound: Int, removed: Int, parent: CallbackStack[A]): Int = {
  // keep the original read order: the callback slot first, then the link to the next cell
  val vacant = callback == null
  val next = get()

  if (vacant) {
    // doing this cas here ultimately deoptimizes contiguous empty chunks
    if (!parent.compareAndSet(this, next)) {
      // if we're contending with another pack(), just bail and let them continue
      removed
    } else if (next == null) {
      // bottomed out: the cas succeeded, so this (last) cell was unlinked and must be counted,
      // otherwise the caller's externally tracked bound drifts upward by one per tail removal
      removed + 1
    } else {
      // note this can cause the bound to go negative, which is fine
      next.packInternal(bound - 1, removed + 1, parent)
    }
  } else if (next == null || bound <= 0) {
    // either bottomed out on a populated cell, or the traversal budget is exhausted
    removed
  } else {
    // this cell stays, so it becomes the parent for whatever follows
    next.packInternal(bound - 1, removed, this)
  }
}
}

private object CallbackStack {
Expand Down
38 changes: 24 additions & 14 deletions core/shared/src/main/scala/cats/effect/IODeferred.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ import cats.effect.syntax.all._
import cats.syntax.all._
import cats.~>

import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

private final class IODeferred[A] extends Deferred[IO, A] {
import IODeferred.Sentinel

private[this] val cell = new AtomicReference[AnyRef](Sentinel)
private[this] val callbacks = CallbackStack[Right[Nothing, A]](null)
private[this] val clearCounter = new AtomicInteger

def complete(a: A): IO[Boolean] = IO {
if (cell.compareAndSet(Sentinel, a.asInstanceOf[AnyRef])) {
Expand All @@ -45,23 +46,32 @@ private final class IODeferred[A] extends Deferred[IO, A] {
IO.cont[A, A](new Cont[IO, A, A] {
def apply[G[_]: MonadCancelThrow] = {
(cb: Either[Throwable, A] => Unit, get: G[A], lift: IO ~> G) =>
MonadCancel[G] uncancelable { poll =>
val gga = lift {
IO {
val stack = callbacks.push(cb)
val handle = stack.currentHandle()
MonadCancel[G] uncancelable {
poll =>
val gga = lift {
IO {
val stack = callbacks.push(cb)
val handle = stack.currentHandle()

val back = cell.get()
if (back eq Sentinel) {
poll(get).onCancel(lift(IO(stack.clearCurrent(handle))))
} else {
stack.clearCurrent(handle)
back.asInstanceOf[A].pure[G]
// Unregisters this waiter's callback and, whenever the running count of clears reaches a
// power of two, packs the callback stack and credits the removed cells back to the counter.
def clear(): Unit = {
  stack.clearCurrent(handle)

  val cleared = clearCounter.incrementAndGet()
  // a value with at most one bit set shares no bits with its predecessor
  val atPowerOfTwo = (cleared & (cleared - 1)) == 0
  if (atPowerOfTwo) {
    val reclaimed = callbacks.pack(cleared)
    clearCounter.addAndGet(-reclaimed)
  }

  ()
}

val back = cell.get()
if (back eq Sentinel) {
poll(get).onCancel(lift(IO(clear())))
} else {
clear()
back.asInstanceOf[A].pure[G]
}
}
}
}

gga.flatten
gga.flatten
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class DeferredSpec extends BaseSpec { outer =>
import cats.syntax.all._

for {
d <- Deferred[IO, Int]
d <- deferredI
attemptCompletion = { (n: Int) => d.complete(n).void }
res <- List(
IO.race(attemptCompletion(1), attemptCompletion(2)).void,
Expand All @@ -170,5 +170,42 @@ class DeferredSpec extends BaseSpec { outer =>
} yield r
}

// Stress test for canceled waiters. For each n in (10, 100, 1000), and 100 times per n, a fresh
// deferred is created and two things run concurrently against it (see the inline notes).
// NOTE(review): `deferredU` and `real` are defined outside this view; presumably `deferredU`
// builds a Deferred[IO, Unit] — confirm.
"handle lots of canceled gets in parallel" in real {
List(10, 100, 1000)
.traverse_ { n =>
deferredU
.flatMap { d =>
// background side: n times over, start a `get` in the background, yield once with IO.cede,
// and leave the `background` scope (which cancels that `get`); afterwards complete `d`
(d.get.background.surround(IO.cede).replicateA_(n) *> d
.complete(())).background.surround {
// foreground side: n parallel `get`s, each mapped to 1; the sum equals n only if every
// one of them was woken by the completion
d.get.as(1).parReplicateA(n).map(_.sum must be_==(n))
}
}
.replicateA_(100)
}
.as(true)
}

// Registers 512 waiters on one deferred, cancels a hand-picked subset of them, completes the
// deferred, and then joins every waiter that was not canceled. The test passes only if all of
// those joins finish, i.e. no surviving callback was dropped along the way.
// NOTE(review): `deferredU`, `ticked` and `completeAs` are defined outside this view — confirm
// their exact semantics there.
"handle adversarial cancelations without loss of callbacks" in ticked { implicit ticker =>
val test = for {
d <- deferredU

range = 0.until(512)
// start the waiters one after another, sleeping 1ms after each start
fibers <- range.toVector.traverse(_ => d.get.start <* IO.sleep(1.millis))

// these are mostly randomly chosen
// the consecutive runs are significant, but only loosely so
// the point is to trigger packing but ensure it isn't always successful
toCancel = List(12, 23, 201, 405, 1, 7, 17, 27, 127, 203, 204, 207, 2, 3, 4, 5)
// cancel sequentially, in exactly the listed order
_ <- toCancel.traverse_(fibers(_).cancel)

_ <- d.complete(())
// indices of the fibers that were never canceled and therefore must observe the completion
remaining = range.toSet -- toCancel

// this will deadlock if any callbacks are lost
_ <- remaining.toList.traverse_(fibers(_).join.void)
} yield ()

test must completeAs(())
}
}
}

0 comments on commit f6adf0f

Please sign in to comment.