diff --git a/core/src/main/scala/akka/persistence/dynamodb/internal/MonotonicTimestamps.scala b/core/src/main/scala/akka/persistence/dynamodb/internal/MonotonicTimestamps.scala
new file mode 100644
index 0000000..e5a5bc0
--- /dev/null
+++ b/core/src/main/scala/akka/persistence/dynamodb/internal/MonotonicTimestamps.scala
@@ -0,0 +1,265 @@
+/*
+ * Copyright (C) 2024 Lightbend Inc.
+ */
+
+package akka.persistence.dynamodb.internal
+
+import akka.Done
+import akka.actor.typed.ActorRef
+import akka.actor.typed.ActorSystem
+import akka.actor.typed.Behavior
+import akka.actor.typed.Extension
+import akka.actor.typed.ExtensionId
+import akka.actor.typed.SupervisorStrategy
+import akka.actor.typed.scaladsl.AskPattern.Askable
+import akka.actor.typed.scaladsl.Behaviors
+import akka.actor.typed.scaladsl.TimerScheduler
+import akka.annotation.InternalApi
+import akka.persistence.Persistence
+
+import scala.annotation.tailrec
+import scala.collection.immutable.SortedSet
+import scala.concurrent.Future
+import scala.concurrent.duration._
+import scala.jdk.CollectionConverters.IteratorHasAsScala
+
+import java.net.URLEncoder
+import java.nio.charset.StandardCharsets
+import java.time.Instant
+import java.time.{ Duration => JDuration }
+import java.time.temporal.ChronoUnit
+import java.util.concurrent.ConcurrentHashMap
+
+object MonotonicTimestamps extends ExtensionId[MonotonicTimestamps] {
+  override def createExtension(system: ActorSystem[_]): MonotonicTimestamps = new MonotonicTimestamps(system)
+
+  def get(system: ActorSystem[_]): MonotonicTimestamps = apply(system)
+
+  /** INTERNAL API */
+  @InternalApi
+  private[internal] final class PerPlugin(
+      system: ActorSystem[_],
+      name: String,
+      numRanges: Int,
+      rangeForPid: String => Int) {
+    // each inner map is only written to by a single actor (its range's cleaner)
+    private val byPid = (0 until numRanges).map { _ => new ConcurrentHashMap[String, Record]() }.toVector
+
+    private val rangeActors = (0 until numRanges).map { range =>
+      val actorName =
+        URLEncoder.encode(s"dynamodb-persistence-monotonic-timestamps-${name}-$range", StandardCharsets.UTF_8)
+      val behavior =
+        Behaviors
+          .supervise(
+            Behaviors
+              .setup[Any] { context =>
+                if (byPid(range).isEmpty) {
+                  Behaviors.withTimers { timers =>
+                    cleanerBehavior(range, SortedSet.empty, timers)
+                  }
+                } else {
+                  // actor restarted: rebuild the sorted set from the surviving map
+                  val recordSet =
+                    byPid(range).values.iterator.asScala
+                      .foldLeft(SortedSet.empty[Record]) { (acc, v) =>
+                        acc.incl(v)
+                      }
+
+                  Behaviors.withTimers { timers =>
+                    scheduleNextCleanup(Instant.now(), recordSet.head.nextTimestamp, timers)
+                    cleanerBehavior(range, recordSet, timers)
+                  }
+                }
+              }
+              .narrow[(Record, ActorRef[Done])])
+          .onFailure(SupervisorStrategy.restart)
+
+      system.systemActorOf(behavior, actorName)
+    }.toVector
+
+    def minTimestampFor(pid: String): Option[Instant] =
+      byPid(rangeForPid(pid)).get(pid) match {
+        case null   => None
+        case record => Some(record.nextTimestamp)
+      }
+
+    def recordTimestampFor(pid: String, timestamp: Instant): Future[Done] = {
+      rangeActors(rangeForPid(pid))
+        .ask[Done]((Record(pid, timestamp.plus(1, ChronoUnit.MICROS)), _))(1.second, system.scheduler)
+    }
+
+    private def scheduleNextCleanup(now: Instant, nextTimestamp: Instant, timers: TimerScheduler[Any]): Unit = {
+      val nextCleanupIn = {
+        val millis =
+          try {
+            JDuration.between(now, nextTimestamp).toMillis / 2
+          } catch {
+            case _: ArithmeticException => 10000 // ten second maximum
+          }
+
+        // minimum 1 second, maximum 10 seconds
+        (millis.min(10000).max(1000)).millis
+      }
+
+      timers.startSingleTimer(Cleanup, Cleanup, nextCleanupIn)
+    }
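+
+    // Single-writer invariant: the range's cleaner actor below is the only
+    // writer to byPid(range). Record messages advance a persistence ID's
+    // minimum next timestamp (acknowledged with Done once recorded), and the
+    // Cleanup tick drops records whose next timestamp is already in the past,
+    // since the wall clock alone then guarantees monotonicity for those IDs.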
+    private[internal] def cleanerBehavior(
+        range: Int,
+        recordSet: SortedSet[Record],
+        timers: TimerScheduler[Any]): Behavior[Any] =
+      Behaviors.receive { (context, msg) =>
+        msg match {
+          case Cleanup =>
+            // next timestamp will be greater than this
+            val keepAfter = InstantFactory.now()
+
+            // adding the nano ensures that this will compare greater than
+            // any record with a timestamp of keepAfter;
+            // "" as pid (not a legal pid) will compare less than any record with the same timestamp;
+            // the net effect is to swap the clusivity of rangeFrom/rangeTo
+            val pivotRecord = Record("", keepAfter.plusNanos(1))
+
+            val recordsToKeep = recordSet.rangeFrom(pivotRecord)
+            val recordsToDrop = recordSet.rangeTo(pivotRecord)
+
+            val kept =
+              recordsToDrop.foldLeft(recordsToKeep) { (rtk, record) =>
+                val pid = record.pid
+
+                if (byPid(range).remove(pid, record)) rtk
+                else {
+                  context.log.warn(
+                    "Concurrent modification of state: this should not happen. Report issue at github.com/akka/akka-persistence-dynamodb")
+                  rtk.incl(byPid(range).get(pid))
+                }
+              }
+
+            if (kept.nonEmpty) { scheduleNextCleanup(keepAfter, kept.head.nextTimestamp, timers) }
+
+            cleanerBehavior(range, kept, timers)
+
+          case (rec: Record, replyTo: ActorRef[Nothing]) =>
+            val pid = rec.pid
+            val nextRecordSet =
+              byPid(range).get(pid) match {
+                case null =>
+                  if (byPid(range).putIfAbsent(pid, rec) eq null) {
+                    replyTo.unsafeUpcast[Done] ! Done
+                    recordSet.incl(rec)
+                  } else {
+                    context.log.warn(
+                      "Timestamp not updated for persistence ID [{}]. " +
+                        "Report issue at github.com/akka/akka-persistence-dynamodb",
+                      pid)
+
+                    // no reply
+                    recordSet
+                  }
+
+                case oldRecord =>
+                  if (oldRecord.nextTimestamp.isBefore(rec.nextTimestamp)) {
+                    // single writer: the previous value is expected to be oldRecord
+                    byPid(range).put(pid, rec) match {
+                      case expected if expected eq oldRecord =>
+                        replyTo.unsafeUpcast[Done] ! Done
+                        recordSet.excl(oldRecord).incl(rec)
+
+                      case unexpected =>
+                        context.log.warn(
+                          "Timestamp not updated for persistence ID [{}]. " +
+                            "Report issue at github.com/akka/akka-persistence-dynamodb",
+                          pid)
+                        recordSet.excl(oldRecord).incl(unexpected)
+                    }
+                  } else {
+                    context.log.warn(
+                      "Ignoring attempt to set timestamp for persistence ID [{}] to earlier. " +
+                        "existing=[{}] attempted=[{}]",
+                      pid,
+                      oldRecord.nextTimestamp,
+                      rec.nextTimestamp)
+
+                    replyTo.unsafeUpcast[Done] ! Done
+                    recordSet
+                  }
+              }
+
+            if (!timers.isTimerActive(Cleanup) && nextRecordSet.nonEmpty) {
+              scheduleNextCleanup(Instant.now(), rec.nextTimestamp, timers)
+            }
+
+            cleanerBehavior(range, nextRecordSet, timers)
+
+          case _ => Behaviors.unhandled
+        }
+      }
+  }
+
+  /** INTERNAL API */
+  @InternalApi
+  private[internal] case class Record(pid: String, nextTimestamp: Instant)
+
+  /** INTERNAL API */
+  @InternalApi
+  private[internal] val Cleanup = "Cleanup"
+
+  private object Record {
+    // order by timestamp, breaking ties by persistence ID
+    implicit val ordering: Ordering[Record] =
+      new Ordering[Record] {
+        override def compare(x: Record, y: Record): Int =
+          x.nextTimestamp.compareTo(y.nextTimestamp) match {
+            case 0      => x.pid.compareTo(y.pid)
+            case result => result
+          }
+      }
+  }
+}
+
+final class MonotonicTimestamps(system: ActorSystem[_]) extends Extension {
+  import MonotonicTimestamps.PerPlugin
+
+  private val persistenceExt = Persistence(system)
+
+  private val numRanges =
+    // minimize contention by having at least as many ranges as available processors
+    Runtime.getRuntime.availableProcessors match {
+      case lt2 if lt2 < 2          => 1
+      case gt1024 if gt1024 > 1024 => 1024
+      case numProcs =>
+        val clz = Integer.numberOfLeadingZeros(numProcs - 1)
+        1 << (32 - clz) // next highest power of 2
+    }
+
+  private val rawRanges = persistenceExt.sliceRanges(numRanges)
+  private val starts = rawRanges.map(_.head).toArray
+
+  private val rangeForPid = (pid: String) => {
+    val slice = persistenceExt.sliceForPersistenceId(pid)
+
+    // binary search for the range whose start slice is at or below this slice
+    @tailrec
+    def iter(lo: Int, hi: Int): Int =
+      if ((lo + 1) >= hi) lo
+      else {
+        val pivot = (lo + hi) / 2
+        val p = starts(pivot)
+
+        if (p == slice) pivot
+        else if (p < slice) iter(pivot, hi)
+        else iter(lo, pivot)
+      }
+
+    iter(0, starts.length)
+  }
+
+  private val perPlugin = new ConcurrentHashMap[String, PerPlugin]()
+
+  def minTimestampFor(plugin: String): String => Option[Instant] = {
+    val pp =
+      perPlugin.computeIfAbsent(plugin, _ => new PerPlugin(system, plugin, numRanges, rangeForPid))
+
+    pp.minTimestampFor _
+  }
+
+  def recordTimestampFor(plugin: String): (String, Instant) => Future[Done] = {
+    val pp =
+      perPlugin.computeIfAbsent(plugin, _ => new PerPlugin(system, plugin, numRanges, rangeForPid))
+
+    pp.recordTimestampFor _
+  }
+}
diff --git a/core/src/main/scala/akka/persistence/dynamodb/journal/DynamoDBJournal.scala b/core/src/main/scala/akka/persistence/dynamodb/journal/DynamoDBJournal.scala
index b408ce4..76d5d65 100644
--- a/core/src/main/scala/akka/persistence/dynamodb/journal/DynamoDBJournal.scala
+++ b/core/src/main/scala/akka/persistence/dynamodb/journal/DynamoDBJournal.scala
@@ -10,7 +10,6 @@ import java.util.concurrent.CompletionException
 import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
 import scala.util.Failure
-import scala.util.Success
 import scala.util.Try
 
 import akka.Done
@@ -25,6 +24,7 @@ import akka.persistence.SerializedEvent
 import akka.persistence.dynamodb.DynamoDBSettings
 import akka.persistence.dynamodb.internal.InstantFactory
 import akka.persistence.dynamodb.internal.JournalDao
+import akka.persistence.dynamodb.internal.MonotonicTimestamps
 import akka.persistence.dynamodb.internal.PubSub
 import akka.persistence.dynamodb.internal.SerializedEventMetadata
 import akka.persistence.dynamodb.internal.SerializedJournalItem
@@ -102,6 +102,10 @@ private[dynamodb] final class DynamoDBJournal(config: Config, cfgPath: String)
     if (settings.journalPublishEvents) Some(PubSub(system))
     else None
 
+  private val monotonicTimestamps = MonotonicTimestamps(system)
+  private val minTimestampFor: String => Option[Instant] = monotonicTimestamps.minTimestampFor(cfgPath)
+  private val recordTimestampFor: (String, Instant) => Future[Done] = monotonicTimestamps.recordTimestampFor(cfgPath)
+
   // if there are pending writes when an actor restarts we must wait for
   // them to complete before we can read the highest sequence number or we will miss it
   private val writesInProgress = new java.util.HashMap[String, Future[Seq[Try[Unit]]]]()
@@ -112,73 +116,86 @@ private[dynamodb] final class DynamoDBJournal(config: Config, cfgPath: String)
 
   override def asyncWriteMessages(messages: Seq[AtomicWrite]): Future[Seq[Try[Unit]]] = {
     def atomicWrite(atomicWrite: AtomicWrite): Future[Seq[Try[Unit]]] = {
-      val serialized: Try[Seq[SerializedJournalItem]] = Try {
-        atomicWrite.payload.map { pr =>
-          val (event, tags) = pr.payload match {
-            case Tagged(payload, tags) =>
-              (payload.asInstanceOf[AnyRef], tags)
-            case other =>
-              (other.asInstanceOf[AnyRef], Set.empty[String])
-          }
-
-          val serializedEvent = event match {
-            case s: SerializedEvent => s // already serialized
-            case _ =>
-              val bytes = serialization.serialize(event).get
-              val serializer = serialization.findSerializerFor(event)
-              val manifest = Serializers.manifestFor(serializer, event)
-              new SerializedEvent(bytes, serializer.identifier, manifest)
-          }
+      val serialized: Future[Seq[SerializedJournalItem]] =
+        atomicWrite.payload
+          .foldLeft(Future.successful(List.empty[SerializedJournalItem])) { (acc, pr) =>
+            acc.flatMap { previousItems =>
+              val (event, tags) = pr.payload match {
+                case Tagged(payload, tags) =>
+                  (payload.asInstanceOf[AnyRef], tags)
+                case other =>
+                  (other.asInstanceOf[AnyRef], Set.empty[String])
+              }
 
-          val metadata = pr.metadata.map { meta =>
-            val m = meta.asInstanceOf[AnyRef]
-            val serializedMeta = serialization.serialize(m).get
-            val metaSerializer = serialization.findSerializerFor(m)
-            val metaManifest = Serializers.manifestFor(metaSerializer, m)
-            val id: Int = metaSerializer.identifier
-            SerializedEventMetadata(id, metaManifest, serializedMeta)
-          }
+              val serializedEvent = event match {
+                case s: SerializedEvent => s // already serialized
+                case _ =>
+                  val bytes = serialization.serialize(event).get
+                  val serializer = serialization.findSerializerFor(event)
+                  val manifest = Serializers.manifestFor(serializer, event)
+                  new SerializedEvent(bytes, serializer.identifier, manifest)
+              }
 
-          // monotonically increasing, at least 1 microsecond more than previous timestamp
-          val timestamp = InstantFactory.now()
-
-          SerializedJournalItem(
-            pr.persistenceId,
-            pr.sequenceNr,
-            timestamp,
-            InstantFactory.EmptyTimestamp,
-            Some(serializedEvent.bytes),
-            serializedEvent.serializerId,
-            serializedEvent.serializerManifest,
-            pr.writerUuid,
-            tags,
-            metadata)
-        }
-      }
+              val metadata = pr.metadata.map { meta =>
+                val m = meta.asInstanceOf[AnyRef]
+                val serializedMeta = serialization.serialize(m).get
+                val metaSerializer = serialization.findSerializerFor(m)
+                val metaManifest = Serializers.manifestFor(metaSerializer, m)
+                val id: Int = metaSerializer.identifier
+                SerializedEventMetadata(id, metaManifest, serializedMeta)
+              }
 
-      serialized match {
-        case Success(writes) =>
-          journalDao
-            .writeEvents(writes)
-            .map { _ =>
-              pubSub.foreach { ps =>
-                atomicWrite.payload.zip(writes).foreach { case (pr, serialized) =>
-                  ps.publish(pr, serialized.writeTimestamp)
-                }
+              // monotonically increasing, at least 1 microsecond more than previous timestamp
+              val timestampFut = {
+                val now = InstantFactory.now()
+                minTimestampFor(pr.persistenceId)
+                  .fold(Future.successful(now)) { min =>
+                    if (min.isAfter(now)) {
+                      log.warning(
+                        "Detected possible clock skew: current timestamp [{}], required for monotonicity [{}]",
+                        now,
+                        min)
+                      recordTimestampFor(pr.persistenceId, min).map(_ => min)(ExecutionContext.parasitic)
+                    } else Future.successful(now)
+                  }
               }
-              Nil // successful writes
+
+              timestampFut.map { timestamp =>
+                SerializedJournalItem(
+                  pr.persistenceId,
+                  pr.sequenceNr,
+                  timestamp,
+                  InstantFactory.EmptyTimestamp,
+                  Some(serializedEvent.bytes),
+                  serializedEvent.serializerId,
+                  serializedEvent.serializerManifest,
+                  pr.writerUuid,
+                  tags,
+                  metadata) :: previousItems
+              }(ExecutionContext.parasitic)
             }
-            .recoverWith { case e: CompletionException =>
-              e.getCause match {
-                case error: ProvisionedThroughputExceededException => // reject retryable errors
-                  Future.successful(atomicWrite.payload.map(_ => Failure(error)))
-                case error => // otherwise journal failure
-                  Future.failed(error)
+          }
+          .map(_.reverse)(ExecutionContext.parasitic)
+
+      serialized.flatMap { writes =>
+        journalDao
+          .writeEvents(writes)
+          .map { _ =>
+            pubSub.foreach { ps =>
+              atomicWrite.payload.zip(writes).foreach { case (pr, serialized) =>
+                ps.publish(pr, serialized.writeTimestamp)
              }
            }
-
-        case Failure(exc) =>
-          Future.failed(exc)
+            Nil // successful writes
+          }
+          .recoverWith { case e: CompletionException =>
+            e.getCause match {
+              case error: ProvisionedThroughputExceededException => // reject retryable errors
+                Future.successful(atomicWrite.payload.map(_ => Failure(error)))
+              case error => // otherwise journal failure
+                Future.failed(error)
+            }
+          }
       }
     }
 
@@ -231,16 +248,25 @@ private[dynamodb] final class DynamoDBJournal(config: Config, cfgPath: String)
     pendingWrite.flatMap { _ =>
       if (toSequenceNr == Long.MaxValue && max == Long.MaxValue) {
         // this is the normal case, highest sequence number from last event
-        query
-          .internalCurrentEventsByPersistenceId(persistenceId, fromSequenceNr, toSequenceNr, includeDeleted = true)
-          .runWith(Sink.fold(0L) { (_, item) =>
-            // payload is empty for deleted item
-            if (item.payload.isDefined) {
-              val repr = deserializeItem(serialization, item)
-              recoveryCallback(repr)
-            }
-            item.seqNr
-          })
+        val lastItem =
+          query
+            .internalCurrentEventsByPersistenceId(persistenceId, fromSequenceNr, toSequenceNr, includeDeleted = true)
+            .runWith(Sink.fold(Option.empty[SerializedJournalItem]) { (_, item) =>
+              // payload is empty for deleted item
+              if (item.payload.isDefined) {
+                val repr = deserializeItem(serialization, item)
+                recoveryCallback(repr)
+              }
+              Some(item)
+            })
+
+        lastItem.flatMap { itemOpt =>
+          itemOpt.fold(Future.successful[Long](0)) { item =>
+            recordTimestampFor(item.persistenceId, item.eventTimestamp).map { _ =>
+              item.seqNr
+            }(ExecutionContext.parasitic)
+          }
+        }(ExecutionContext.parasitic)
       } else if (toSequenceNr <= 0) {
         // no replay
         journalDao.readHighestSequenceNr(persistenceId)
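Taken together: replay records the timestamp of the last stored event, and later writes consult the recorded floor, so event timestamps for a persistence ID never move backwards even across journal restarts or wall-clock skew. A minimal sketch of the extension's per-plugin API as used above; the object name, the wire method, and the example persistence ID and timestamp are illustrative assumptions, and "akka.persistence.dynamodb.journal" merely stands in for the journal's cfgPath:

import java.time.Instant

import scala.concurrent.Future

import akka.Done
import akka.actor.typed.ActorSystem
import akka.persistence.dynamodb.internal.MonotonicTimestamps

object MonotonicTimestampsSketch {
  // Hypothetical wiring, mirroring what DynamoDBJournal does with cfgPath.
  def wire(system: ActorSystem[_]): Unit = {
    val timestamps = MonotonicTimestamps(system)

    // One function pair per journal plugin instance, keyed by config path.
    val minFor: String => Option[Instant] =
      timestamps.minTimestampFor("akka.persistence.dynamodb.journal")
    val record: (String, Instant) => Future[Done] =
      timestamps.recordTimestampFor("akka.persistence.dynamodb.journal")

    // Replay: remember the last stored event timestamp for the persistence ID.
    // The extension stores it plus one microsecond as the minimum next timestamp.
    val acked: Future[Done] = record("pid-1", Instant.parse("2024-01-01T00:00:00Z"))

    // Write path: if the floor is after InstantFactory.now(), the journal logs
    // possible clock skew and uses the floor instead of the current time.
    val floor: Option[Instant] = minFor("pid-1")
  }
}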