Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a custom way to handle SSL exceptions #849

Open
wants to merge 1 commit into
base: series/0.23
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ target/

*~
/.bsp/
# Ignore Scala Metals configuration and logs.
/.metals/
# Ignore the Bloop build server directories.
.bloop/
# Ignore the auto-generated metals configuration file
**/metals.sbt
# Ignore settings for Visual Studio Code
/.vscode
/.ensime
/.idea/
/.idea_modules/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import com.comcast.ip4s.SocketAddress
import org.http4s.blaze.channel._
import org.http4s.blaze.channel.nio1.NIO1SocketServerGroup
import org.http4s.blaze.pipeline.LeafBuilder
import org.http4s.blaze.pipeline.stages.SSLStage
import org.http4s.blaze.pipeline.stages.{SSLStage, SSLStageDefaults}
import org.http4s.blaze.server.BlazeServerBuilder._
import org.http4s.blaze.util.TickWheelExecutor
import org.http4s.blaze.{BuildInfo => BlazeBuildInfo}
Expand Down Expand Up @@ -194,15 +194,37 @@ class BlazeServerBuilder[F[_]] private (
): Self =
copy(sslConfig = new ContextWithClientAuth[F](sslContext, clientAuth))

/** Configures the server with TLS, using the provided `SSLContext` and its
* default `SSLParameters`
*
* @param sslErrorHandler function that runs if an error occurs during the TLS handshake. Default behavior is to log the error.
*/
def withSslContext(
sslContext: SSLContext,
sslErrorHandler: PartialFunction[Throwable, Unit] = PartialFunction.empty,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be more beneficial to have a general error handler exposed to users and not only for SSL exceptions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what it's worth, Ember's general error handler just came up in http4s/http4s#7399. This is more in the direction that Ember is going, but probably away from the direction I'd prefer to go.

): Self =
copy(sslConfig = new ContextOnly[F](sslContext, sslErrorHandler))

/** Configures the server with TLS, using the provided `SSLContext` and its
* default `SSLParameters`
*/
def withSslContext(sslContext: SSLContext): Self =
copy(sslConfig = new ContextOnly[F](sslContext))
withSslContext(sslContext, PartialFunction.empty)

/** Configures the server with TLS, using the provided `SSLContext` and `SSLParameters`.
*
* @param sslErrorHandler function that runs if an error occurs during the TLS handshake. Default behavior is to log the error.
*/
def withSslContextAndParameters(
sslContext: SSLContext,
sslParameters: SSLParameters,
sslErrorHandler: PartialFunction[Throwable, Unit] = PartialFunction.empty,
): Self =
copy(sslConfig = new ContextWithParameters[F](sslContext, sslParameters, sslErrorHandler))

/** Configures the server with TLS, using the provided `SSLContext` and `SSLParameters`. */
def withSslContextAndParameters(sslContext: SSLContext, sslParameters: SSLParameters): Self =
copy(sslConfig = new ContextWithParameters[F](sslContext, sslParameters))
withSslContextAndParameters(sslContext, sslParameters, PartialFunction.empty)

def withoutSsl: Self =
copy(sslConfig = new NoSsl[F]())
Expand Down Expand Up @@ -277,7 +299,7 @@ class BlazeServerBuilder[F[_]] private (

private def pipelineFactory(
scheduler: TickWheelExecutor,
engineConfig: Option[(SSLContext, SSLEngine => Unit)],
engineConfig: Option[(SSLContextWithExtras, SSLEngine => Unit)],
dispatcher: Dispatcher[F],
)(conn: SocketConnection): Future[LeafBuilder[ByteBuffer]] = {
def requestAttributes(secure: Boolean, optionalSslEngine: Option[SSLEngine]): () => Vault =
Expand Down Expand Up @@ -365,15 +387,16 @@ class BlazeServerBuilder[F[_]] private (
executionContextConfig.getExecutionContext[F].flatMap { executionContext =>
engineConfig match {
case Some((ctx, configure)) =>
val engine = ctx.createSSLEngine()
val engine = ctx.context.createSSLEngine()
engine.setUseClientMode(false)
configure(engine)

val leafBuilder =
if (isHttp2Enabled) http2Stage(executionContext, engine).map(LeafBuilder(_))
else http1Stage(executionContext, secure = true, engine.some).map(LeafBuilder(_))

leafBuilder.map(_.prepend(new SSLStage(engine)))
leafBuilder
.map(_.prepend(new SSLStage(engine, SSLStageDefaults.MaxWrite, ctx.errorHandler)))

case None =>
if (isHttp2Enabled)
Expand Down Expand Up @@ -497,8 +520,17 @@ object BlazeServerBuilder {
private def defaultThreadSelectorFactory: ThreadFactory =
threadFactory(name = n => s"blaze-selector-${n}", daemon = false)

private case class SSLContextWithExtras(
context: SSLContext,
errorHandler: PartialFunction[Throwable, Unit],
)
private object SSLContextWithExtras {
def onlyContext(context: SSLContext): SSLContextWithExtras =
apply(context, PartialFunction.empty)
}

private sealed trait SslConfig[F[_]] {
def makeContext: F[Option[SSLContext]]
def makeContext: F[Option[SSLContextWithExtras]]
def configureEngine(sslEngine: SSLEngine): Unit
def isSecure: Boolean
}
Expand All @@ -511,7 +543,7 @@ object BlazeServerBuilder {
clientAuth: SSLClientAuthMode,
)(implicit F: Sync[F])
extends SslConfig[F] {
def makeContext: F[Option[SSLContext]] =
def makeContext: F[Option[SSLContextWithExtras]] =
F.delay {
val ksStream = new FileInputStream(keyStore.path)
val ks = KeyStore.getInstance("JKS")
Expand Down Expand Up @@ -540,39 +572,59 @@ object BlazeServerBuilder {

val context = SSLContext.getInstance(protocol)
context.init(kmf.getKeyManagers, tmf.orNull, null)
context.some
SSLContextWithExtras(context, PartialFunction.empty).some
}
def configureEngine(engine: SSLEngine): Unit =
configureEngineFromSslClientAuthMode(engine, clientAuth)
def isSecure: Boolean = true
}

private class ContextOnly[F[_]](sslContext: SSLContext)(implicit F: Applicative[F])
private class ContextOnly[F[_]](
sslContext: SSLContext,
sslErrorHandler: PartialFunction[Throwable, Unit],
)(implicit F: Applicative[F])
extends SslConfig[F] {
def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some)

/** Constructor for backwards compatibility */
def this(sslContext: SSLContext)(implicit F: Applicative[F]) =
this(sslContext, PartialFunction.empty)

def makeContext: F[Option[SSLContextWithExtras]] =
F.pure(SSLContextWithExtras(sslContext, sslErrorHandler).some)
def configureEngine(engine: SSLEngine): Unit = ()
def isSecure: Boolean = true
}

private class ContextWithParameters[F[_]](sslContext: SSLContext, sslParameters: SSLParameters)(
implicit F: Applicative[F]
private class ContextWithParameters[F[_]](
sslContext: SSLContext,
sslParameters: SSLParameters,
sslErrorHandler: PartialFunction[Throwable, Unit],
)(implicit
F: Applicative[F]
) extends SslConfig[F] {
def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some)

/** Constructor for backwards compatibility */
def this(sslContext: SSLContext, sslParameters: SSLParameters)(implicit F: Applicative[F]) =
this(sslContext, sslParameters, PartialFunction.empty)

def makeContext: F[Option[SSLContextWithExtras]] =
F.pure(SSLContextWithExtras(sslContext, sslErrorHandler).some)
def configureEngine(engine: SSLEngine): Unit = engine.setSSLParameters(sslParameters)
def isSecure: Boolean = true
}

private class ContextWithClientAuth[F[_]](sslContext: SSLContext, clientAuth: SSLClientAuthMode)(
implicit F: Applicative[F]
) extends SslConfig[F] {
def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some)
def makeContext: F[Option[SSLContextWithExtras]] =
F.pure(SSLContextWithExtras.onlyContext(sslContext).some)
def configureEngine(engine: SSLEngine): Unit =
configureEngineFromSslClientAuthMode(engine, clientAuth)
def isSecure: Boolean = true
}

private class NoSsl[F[_]]()(implicit F: Applicative[F]) extends SslConfig[F] {
def makeContext: F[Option[SSLContext]] = F.pure(None)
def makeContext: F[Option[SSLContextWithExtras]] = F.pure(None)
def configureEngine(engine: SSLEngine): Unit = ()
def isSecure: Boolean = false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,30 @@ private object SSLStage {
private case class SSLFailure(t: Throwable) extends SSLResult
}

final class SSLStage(engine: SSLEngine, maxWrite: Int = 1024 * 1024)
extends MidStage[ByteBuffer, ByteBuffer] {
/** Default values for the [[SSLStage]] constructors.
*
* A separate `object` because the [[SSLStage]] `object` is `private`.
*/
object SSLStageDefaults {

/** Default value for [[SSLStage.maxWrite]] */
final val MaxWrite = 1024 * 1024
}

/** @param maxWrite
* \@see [[SSLStageDefaults.MaxWrite]].
*/
Comment on lines +67 to +69
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like missing scaladoc

final class SSLStage(
engine: SSLEngine,
maxWrite: Int,
sslHandshakeExceptionHandler: PartialFunction[Throwable, Unit]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's expected to handle SSLExceptions only, shouldn't the contract be PF[SSLException, Unit]?

) extends MidStage[ByteBuffer, ByteBuffer] {
import SSLStage._

/** Constructor to keep backwards compatibility with old versions. */
def this(engine: SSLEngine, maxWrite: Int = SSLStageDefaults.MaxWrite) =
this(engine, maxWrite, PartialFunction.empty)

def name: String = "SSLStage"

// We use a serial executor to ensure single threaded behavior. This makes
Expand Down Expand Up @@ -248,13 +268,15 @@ final class SSLStage(engine: SSLEngine, maxWrite: Int = 1024 * 1024)
val start = System.nanoTime
try sslHandshakeLoop(data, r)
catch {
case t: SSLException =>
logger.warn(t)("SSLException in SSL handshake")
handshakeFailure(t)

case NonFatal(t) =>
logger.error(t)("Error in SSL handshake")
handshakeFailure(t)
try
if (sslHandshakeExceptionHandler.isDefinedAt(t)) sslHandshakeExceptionHandler(t)
else
t match {
case t: SSLException => logger.warn(t)("SSLException in SSL handshake")
case _ => logger.error(t)("Error in SSL handshake")
}
finally handshakeFailure(t)
}
logger.trace(s"${engine.##}: sslHandshake completed in ${System.nanoTime - start}ns")
}
Expand Down
Loading