Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: reuse allocated buffer in stream reads for TCP/Unix sockets #3411

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,13 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"fs2.io.file.PosixFileAttributes.fs2$io$file$PosixFileAttributes$$super#Code"
),
// private classes
ProblemFilters.exclude[DirectMissingMethodProblem](
"fs2.io.net.SocketCompanionPlatform#AsyncSocket.this"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"fs2.io.net.unixsocket.UnixSocketsCompanionPlatform#AsyncSocket.this"
)
)

Expand Down
153 changes: 95 additions & 58 deletions io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package io
package net

import com.comcast.ip4s.{IpAddress, SocketAddress}
import cats.effect.Async
import cats.effect.{Async, Resource}
import cats.effect.std.Mutex
import cats.syntax.all._

Expand All @@ -33,82 +33,125 @@ import java.nio.channels.{AsynchronousSocketChannel, CompletionHandler}
import java.nio.{Buffer, ByteBuffer}

private[net] trait SocketCompanionPlatform {

/** Creates a [[Socket]] instance for given `AsynchronousSocketChannel`
* with 16 KiB max. read chunk size and exclusive access guards for both reads abd writes.
*/
private[net] def forAsync[F[_]: Async](
ch: AsynchronousSocketChannel
): F[Socket[F]] =
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
new AsyncSocket[F](ch, readMutex, writeMutex)
forAsync(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true)

/** Creates a [[Socket]] instance for given `AsynchronousSocketChannel`.
*
* @param ch async socket channel for actual reads and writes
* @param maxReadChunkSize maximum chunk size for [[Socket#reads]] method
* @param withExclusiveReads set to `true` if reads should be guarded by mutex
* @param withExclusiveWrites set to `true` if writes should be guarded by mutex
*/
private[net] def forAsync[F[_]](
ch: AsynchronousSocketChannel,
maxReadChunkSize: Int,
withExclusiveReads: Boolean = false,
withExclusiveWrites: Boolean = false
)(implicit F: Async[F]): F[Socket[F]] = {
def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None))
(maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN {
(readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize)
}
}

private[net] abstract class BufferedReads[F[_]](
readMutex: Mutex[F]
readMutex: Option[Mutex[F]],
writeMutex: Option[Mutex[F]],
maxReadChunkSize: Int
)(implicit F: Async[F])
extends Socket[F] {
private[this] final val defaultReadSize = 8192
private[this] var readBuffer: ByteBuffer = ByteBuffer.allocate(defaultReadSize)
private def lock(mutex: Option[Mutex[F]]): Resource[F, Unit] =
mutex match {
case Some(mutex) => mutex.lock
case None => Resource.unit
}

private def withReadBuffer[A](size: Int)(f: ByteBuffer => F[A]): F[A] =
readMutex.lock.surround {
F.delay {
if (readBuffer.capacity() < size)
readBuffer = ByteBuffer.allocate(size)
else
(readBuffer: Buffer).limit(size)
f(readBuffer)
}.flatten
lock(readMutex).surround {
F.delay(ByteBuffer.allocate(size)).flatMap(f)
}

/** Performs a single channel read operation in to the supplied buffer. */
protected def readChunk(buffer: ByteBuffer): F[Int]

/** Copies the contents of the supplied buffer to a `Chunk[Byte]` and clears the buffer contents. */
private def releaseBuffer(buffer: ByteBuffer): F[Chunk[Byte]] =
F.delay {
val read = buffer.position()
val result =
if (read == 0) Chunk.empty
else {
val dest = new Array[Byte](read)
(buffer: Buffer).flip()
buffer.get(dest)
Chunk.array(dest)
}
(buffer: Buffer).clear()
result
}
/** Performs a channel write operation(-s) from the supplied buffer.
*
* Write could be performed multiple times till all buffer remaining contents are written.
*/
protected def writeChunk(buffer: ByteBuffer): F[Unit]

def read(max: Int): F[Option[Chunk[Byte]]] =
withReadBuffer(max) { buffer =>
readChunk(buffer).flatMap { read =>
if (read < 0) F.pure(None)
else releaseBuffer(buffer).map(Some(_))
readChunk(buffer).map { read =>
if (read < 0) None
else if (buffer.position() == 0) Some(Chunk.empty)
else {
(buffer: Buffer).flip()
Some(Chunk.byteBuffer(buffer.asReadOnlyBuffer()))
}
}
}

def readN(max: Int): F[Chunk[Byte]] =
withReadBuffer(max) { buffer =>
def go: F[Chunk[Byte]] =
readChunk(buffer).flatMap { readBytes =>
if (readBytes < 0 || buffer.position() >= max)
releaseBuffer(buffer)
else go
if (readBytes < 0 || buffer.position() >= max) {
(buffer: Buffer).flip()
F.pure(Chunk.byteBuffer(buffer.asReadOnlyBuffer()))
} else go
}
go
}

def reads: Stream[F, Byte] =
Stream.repeatEval(read(defaultReadSize)).unNoneTerminate.unchunks
Stream.resource(lock(readMutex)).flatMap { _ =>
Stream.unfoldChunkEval(ByteBuffer.allocate(maxReadChunkSize)) { case buffer =>
readChunk(buffer).flatMap { read =>
if (read < 0) none[(Chunk[Byte], ByteBuffer)].pure
else if (buffer.position() == 0) {
(Chunk.empty[Byte] -> buffer).some.pure
} else if (buffer.remaining() == 0) {
val bytes = buffer.asReadOnlyBuffer()
val fresh = ByteBuffer.allocate(maxReadChunkSize)
(bytes: Buffer).flip()
(Chunk.byteBuffer(bytes) -> fresh).some.pure
} else {
val bytes = buffer.duplicate().asReadOnlyBuffer()
val slice = buffer.slice()
(bytes: Buffer).flip()
(Chunk.byteBuffer(bytes) -> slice).some.pure
}
}
}
}

def write(bytes: Chunk[Byte]): F[Unit] =
lock(writeMutex).surround {
F.delay(bytes.toByteBuffer).flatMap(writeChunk)
}

def writes: Pipe[F, Byte, Nothing] =
_.chunks.foreach(write)
def writes: Pipe[F, Byte, Nothing] = { in =>
Stream.resource(lock(writeMutex)).flatMap { _ =>
in.chunks.foreach(bytes => writeChunk(bytes.toByteBuffer))
}
}
}

private final class AsyncSocket[F[_]](
ch: AsynchronousSocketChannel,
readMutex: Mutex[F],
writeMutex: Mutex[F]
readMutex: Option[Mutex[F]],
writeMutex: Option[Mutex[F]],
maxReadChunkSize: Int
)(implicit F: Async[F])
extends BufferedReads[F](readMutex) {
extends BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) {

protected def readChunk(buffer: ByteBuffer): F[Int] =
F.async[Int] { cb =>
Expand All @@ -120,24 +163,18 @@ private[net] trait SocketCompanionPlatform {
F.delay(Some(endOfInput.voidError))
}

def write(bytes: Chunk[Byte]): F[Unit] = {
def go(buff: ByteBuffer): F[Unit] =
F.async[Int] { cb =>
ch.write(
buff,
null,
new IntCompletionHandler(cb)
)
F.delay(Some(endOfOutput.voidError))
}.flatMap { written =>
if (written >= 0 && buff.remaining() > 0)
go(buff)
else F.unit
}
writeMutex.lock.surround {
F.delay(bytes.toByteBuffer).flatMap(go)
protected def writeChunk(buffer: ByteBuffer): F[Unit] =
F.async[Int] { cb =>
ch.write(
buffer,
null,
new IntCompletionHandler(cb)
)
F.delay(Some(endOfOutput.voidError))
}.flatMap { written =>
if (written < 0 || buffer.remaining() == 0) F.unit
else writeChunk(buffer)
}
}

def localAddress: F[SocketAddress[IpAddress]] =
F.delay(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import cats.effect.std.Mutex
import cats.effect.syntax.all._
import cats.syntax.all._
import com.comcast.ip4s.{IpAddress, SocketAddress}
import fs2.{Chunk, Stream}
import fs2.Stream
import fs2.io.file.{Files, Path}
import fs2.io.net.Socket
import java.nio.ByteBuffer
Expand Down Expand Up @@ -89,29 +89,36 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {
private def makeSocket[F[_]: Async](
ch: SocketChannel
): F[Socket[F]] =
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
new AsyncSocket[F](ch, readMutex, writeMutex)
makeSocket(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true)

private def makeSocket[F[_]](
ch: SocketChannel,
maxReadChunkSize: Int,
withExclusiveReads: Boolean,
withExclusiveWrites: Boolean
)(implicit F: Async[F]): F[Socket[F]] = {
def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None))
(maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN {
(readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize)
}
}

private final class AsyncSocket[F[_]](
ch: SocketChannel,
readMutex: Mutex[F],
writeMutex: Mutex[F]
readMutex: Option[Mutex[F]],
writeMutex: Option[Mutex[F]],
maxReadChunkSize: Int
)(implicit F: Async[F])
extends Socket.BufferedReads[F](readMutex) {
extends Socket.BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) {

def readChunk(buff: ByteBuffer): F[Int] =
F.blocking(ch.read(buff)).cancelable(close)
protected def readChunk(buffer: ByteBuffer): F[Int] =
F.blocking(ch.read(buffer)).cancelable(close)

def write(bytes: Chunk[Byte]): F[Unit] = {
def go(buff: ByteBuffer): F[Unit] =
F.blocking(ch.write(buff)).cancelable(close) *>
F.delay(buff.remaining <= 0).ifM(F.unit, go(buff))

writeMutex.lock.surround {
F.delay(bytes.toByteBuffer).flatMap(go)
protected def writeChunk(buffer: ByteBuffer): F[Unit] =
F.blocking(ch.write(buffer)).cancelable(close).flatMap { _ =>
if (buffer.remaining() == 0) F.unit
else writeChunk(buffer)
}
}

def localAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError
def remoteAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError
Expand Down