Skip to content

Commit

Permalink
reuse allocated buffer in stream reads for TCP/Unix sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
seigert committed Mar 20, 2024
1 parent 1279244 commit d668f8f
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 74 deletions.
153 changes: 95 additions & 58 deletions io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package io
package net

import com.comcast.ip4s.{IpAddress, SocketAddress}
import cats.effect.Async
import cats.effect.{Async, Resource}
import cats.effect.std.Mutex
import cats.syntax.all._

Expand All @@ -33,82 +33,125 @@ import java.nio.channels.{AsynchronousSocketChannel, CompletionHandler}
import java.nio.{Buffer, ByteBuffer}

private[net] trait SocketCompanionPlatform {

/** Creates a [[Socket]] instance for given `AsynchronousSocketChannel`
* with 16 KiB max. read chunk size and exclusive access guards for both reads abd writes.
*/
private[net] def forAsync[F[_]: Async](
ch: AsynchronousSocketChannel
): F[Socket[F]] =
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
new AsyncSocket[F](ch, readMutex, writeMutex)
forAsync(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true)

/** Creates a [[Socket]] instance for given `AsynchronousSocketChannel`.
*
* @param ch async socket channel for actual reads and writes
* @param maxReadChunkSize maximum chunk size for [[Socket#reads]] method
* @param withExclusiveReads set to `true` if reads should be guarded by mutex
* @param withExclusiveWrites set to `true` if writes should be guarded by mutex
*/
private[net] def forAsync[F[_]](
ch: AsynchronousSocketChannel,
maxReadChunkSize: Int,
withExclusiveReads: Boolean = false,
withExclusiveWrites: Boolean = false
)(implicit F: Async[F]): F[Socket[F]] = {
def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None))
(maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN {
(readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize)
}
}

private[net] abstract class BufferedReads[F[_]](
readMutex: Mutex[F]
readMutex: Option[Mutex[F]],
writeMutex: Option[Mutex[F]],
maxReadChunkSize: Int
)(implicit F: Async[F])
extends Socket[F] {
private[this] final val defaultReadSize = 8192
private[this] var readBuffer: ByteBuffer = ByteBuffer.allocate(defaultReadSize)
private def lock(mutex: Option[Mutex[F]]): Resource[F, Unit] =
mutex match {
case Some(mutex) => mutex.lock
case None => Resource.unit
}

private def withReadBuffer[A](size: Int)(f: ByteBuffer => F[A]): F[A] =
readMutex.lock.surround {
F.delay {
if (readBuffer.capacity() < size)
readBuffer = ByteBuffer.allocate(size)
else
(readBuffer: Buffer).limit(size)
f(readBuffer)
}.flatten
lock(readMutex).surround {
F.delay(ByteBuffer.allocate(size)).flatMap(f)
}

/** Performs a single channel read operation in to the supplied buffer. */
protected def readChunk(buffer: ByteBuffer): F[Int]

/** Copies the contents of the supplied buffer to a `Chunk[Byte]` and clears the buffer contents. */
private def releaseBuffer(buffer: ByteBuffer): F[Chunk[Byte]] =
F.delay {
val read = buffer.position()
val result =
if (read == 0) Chunk.empty
else {
val dest = new Array[Byte](read)
(buffer: Buffer).flip()
buffer.get(dest)
Chunk.array(dest)
}
(buffer: Buffer).clear()
result
}
/** Performs a channel write operation(-s) from the supplied buffer.
*
* Write could be performed multiple times till all buffer remaining contents are written.
*/
protected def writeChunk(buffer: ByteBuffer): F[Unit]

def read(max: Int): F[Option[Chunk[Byte]]] =
withReadBuffer(max) { buffer =>
readChunk(buffer).flatMap { read =>
if (read < 0) F.pure(None)
else releaseBuffer(buffer).map(Some(_))
readChunk(buffer).map { read =>
if (read < 0) None
else if (buffer.position() == 0) Some(Chunk.empty)
else {
(buffer: Buffer).flip()
Some(Chunk.byteBuffer(buffer.asReadOnlyBuffer()))
}
}
}

def readN(max: Int): F[Chunk[Byte]] =
withReadBuffer(max) { buffer =>
def go: F[Chunk[Byte]] =
readChunk(buffer).flatMap { readBytes =>
if (readBytes < 0 || buffer.position() >= max)
releaseBuffer(buffer)
else go
if (readBytes < 0 || buffer.position() >= max) {
(buffer: Buffer).flip()
F.pure(Chunk.byteBuffer(buffer.asReadOnlyBuffer()))
} else go
}
go
}

def reads: Stream[F, Byte] =
Stream.repeatEval(read(defaultReadSize)).unNoneTerminate.unchunks
Stream.resource(lock(readMutex)).flatMap { _ =>
Stream.unfoldChunkEval(ByteBuffer.allocate(maxReadChunkSize)) { case buffer =>
readChunk(buffer).flatMap { read =>
if (read < 0) none[(Chunk[Byte], ByteBuffer)].pure
else if (buffer.position() == 0) {
(Chunk.empty[Byte] -> buffer).some.pure
} else if (buffer.remaining() == 0) {
val bytes = buffer.asReadOnlyBuffer()
val fresh = ByteBuffer.allocate(maxReadChunkSize)
(bytes: Buffer).flip()
(Chunk.byteBuffer(bytes) -> fresh).some.pure
} else {
val bytes = buffer.duplicate().asReadOnlyBuffer()
val slice = buffer.slice()
(bytes: Buffer).flip()
(Chunk.byteBuffer(bytes) -> slice).some.pure
}
}
}
}

def write(bytes: Chunk[Byte]): F[Unit] =
lock(writeMutex).surround {
F.delay(bytes.toByteBuffer).flatMap(writeChunk)
}

def writes: Pipe[F, Byte, Nothing] =
_.chunks.foreach(write)
def writes: Pipe[F, Byte, Nothing] = { in =>
Stream.resource(lock(writeMutex)).flatMap { _ =>
in.chunks.foreach(bytes => writeChunk(bytes.toByteBuffer))
}
}
}

private final class AsyncSocket[F[_]](
ch: AsynchronousSocketChannel,
readMutex: Mutex[F],
writeMutex: Mutex[F]
readMutex: Option[Mutex[F]],
writeMutex: Option[Mutex[F]],
maxReadChunkSize: Int
)(implicit F: Async[F])
extends BufferedReads[F](readMutex) {
extends BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) {

protected def readChunk(buffer: ByteBuffer): F[Int] =
F.async[Int] { cb =>
Expand All @@ -120,24 +163,18 @@ private[net] trait SocketCompanionPlatform {
F.delay(Some(endOfInput.voidError))
}

def write(bytes: Chunk[Byte]): F[Unit] = {
def go(buff: ByteBuffer): F[Unit] =
F.async[Int] { cb =>
ch.write(
buff,
null,
new IntCompletionHandler(cb)
)
F.delay(Some(endOfOutput.voidError))
}.flatMap { written =>
if (written >= 0 && buff.remaining() > 0)
go(buff)
else F.unit
}
writeMutex.lock.surround {
F.delay(bytes.toByteBuffer).flatMap(go)
protected def writeChunk(buffer: ByteBuffer): F[Unit] =
F.async[Int] { cb =>
ch.write(
buffer,
null,
new IntCompletionHandler(cb)
)
F.delay(Some(endOfOutput.voidError))
}.flatMap { written =>
if (written < 0 || buffer.remaining() == 0) F.unit
else writeChunk(buffer)
}
}

def localAddress: F[SocketAddress[IpAddress]] =
F.delay(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import cats.effect.std.Mutex
import cats.effect.syntax.all._
import cats.syntax.all._
import com.comcast.ip4s.{IpAddress, SocketAddress}
import fs2.{Chunk, Stream}
import fs2.Stream
import fs2.io.file.{Files, Path}
import fs2.io.net.Socket
import java.nio.ByteBuffer
Expand Down Expand Up @@ -89,29 +89,36 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {
private def makeSocket[F[_]: Async](
ch: SocketChannel
): F[Socket[F]] =
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
new AsyncSocket[F](ch, readMutex, writeMutex)
makeSocket(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true)

private def makeSocket[F[_]](
ch: SocketChannel,
maxReadChunkSize: Int,
withExclusiveReads: Boolean,
withExclusiveWrites: Boolean
)(implicit F: Async[F]): F[Socket[F]] = {
def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None))
(maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN {
(readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize)
}
}

private final class AsyncSocket[F[_]](
ch: SocketChannel,
readMutex: Mutex[F],
writeMutex: Mutex[F]
readMutex: Option[Mutex[F]],
writeMutex: Option[Mutex[F]],
maxReadChunkSize: Int
)(implicit F: Async[F])
extends Socket.BufferedReads[F](readMutex) {
extends Socket.BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) {

def readChunk(buff: ByteBuffer): F[Int] =
F.blocking(ch.read(buff)).cancelable(close)
protected def readChunk(buffer: ByteBuffer): F[Int] =
F.blocking(ch.read(buffer)).cancelable(close)

def write(bytes: Chunk[Byte]): F[Unit] = {
def go(buff: ByteBuffer): F[Unit] =
F.blocking(ch.write(buff)).cancelable(close) *>
F.delay(buff.remaining <= 0).ifM(F.unit, go(buff))

writeMutex.lock.surround {
F.delay(bytes.toByteBuffer).flatMap(go)
protected def writeChunk(buffer: ByteBuffer): F[Unit] =
F.blocking(ch.write(buffer)).cancelable(close).flatMap { _ =>
if (buffer.remaining() == 0) F.unit
else writeChunk(buffer)
}
}

def localAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError
def remoteAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError
Expand Down

0 comments on commit d668f8f

Please sign in to comment.