Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom server name resolution #269

Merged
merged 6 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 72 additions & 20 deletions src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct HttpsConnector<T> {
force_https: bool,
http: T,
tls_config: Arc<rustls::ClientConfig>,
override_server_name: Option<String>,
server_name_resolver: Arc<dyn ResolveServerName + Sync + Send>,
}

impl<T> HttpsConnector<T> {
Expand Down Expand Up @@ -90,24 +90,10 @@ where
};

let cfg = self.tls_config.clone();
let mut hostname = match self.override_server_name.as_deref() {
Some(h) => h,
None => dst.host().unwrap_or_default(),
};

// Remove square brackets around IPv6 address.
if let Some(trimmed) = hostname
.strip_prefix('[')
.and_then(|h| h.strip_suffix(']'))
{
hostname = trimmed;
}

let hostname = match ServerName::try_from(hostname) {
Ok(dns_name) => dns_name.to_owned(),
Err(_) => {
let err = io::Error::new(io::ErrorKind::Other, "invalid dnsname");
return Box::pin(async move { Err(Box::new(err).into()) });
let hostname = match self.server_name_resolver.resolve(&dst) {
Ok(hostname) => hostname,
Err(e) => {
return Box::pin(async move { Err(e) });
}
};

Expand Down Expand Up @@ -135,7 +121,7 @@ where
force_https: false,
http,
tls_config: cfg.into(),
override_server_name: None,
server_name_resolver: Arc::new(DefaultServerNameResolver::default()),
}
}
}
Expand All @@ -147,3 +133,69 @@ impl<T> fmt::Debug for HttpsConnector<T> {
.finish()
}
}

/// The default server name resolver, which uses the hostname in the URI.
#[derive(Default)]
pub struct DefaultServerNameResolver(());

impl ResolveServerName for DefaultServerNameResolver {
fn resolve(
&self,
uri: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
let mut hostname = uri.host().unwrap_or_default();

// Remove square brackets around IPv6 address.
if let Some(trimmed) = hostname
.strip_prefix('[')
.and_then(|h| h.strip_suffix(']'))
{
hostname = trimmed;
}

ServerName::try_from(hostname.to_string()).map_err(|e| Box::new(e) as _)
}
}

/// A server name resolver which always returns the same fixed name.
pub struct FixedServerNameResolver {
name: ServerName<'static>,
}

impl FixedServerNameResolver {
/// Creates a new resolver returning the specified name.
pub fn new(name: ServerName<'static>) -> Self {
Self { name }
}
}

impl ResolveServerName for FixedServerNameResolver {
fn resolve(
&self,
_: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
Ok(self.name.clone())
}
}

impl<F, E> ResolveServerName for F
where
F: Fn(&Uri) -> Result<ServerName<'static>, E>,
E: Into<Box<dyn std::error::Error + Sync + Send>>,
{
fn resolve(
&self,
uri: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
self(uri).map_err(Into::into)
}
}

/// A trait implemented by types that can resolve a [`ServerName`] for a request.
pub trait ResolveServerName {
/// Maps a [`Uri`] into a [`ServerName`].
fn resolve(
&self,
uri: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>>;
}
50 changes: 42 additions & 8 deletions src/connector/builder.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::sync::Arc;

use hyper_util::client::legacy::connect::HttpConnector;
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
use rustls::crypto::CryptoProvider;
use rustls::ClientConfig;

use super::HttpsConnector;
use super::{DefaultServerNameResolver, HttpsConnector, ResolveServerName};
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
use crate::config::ConfigBuilderExt;
use pki_types::ServerName;

/// A builder for an [`HttpsConnector`]
///
Expand Down Expand Up @@ -153,7 +156,7 @@ impl ConnectorBuilder<WantsSchemes> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: true,
override_server_name: None,
server_name_resolver: None,
})
}

Expand All @@ -165,7 +168,7 @@ impl ConnectorBuilder<WantsSchemes> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: false,
override_server_name: None,
server_name_resolver: None,
})
}
}
Expand All @@ -177,7 +180,7 @@ impl ConnectorBuilder<WantsSchemes> {
pub struct WantsProtocols1 {
tls_config: ClientConfig,
https_only: bool,
override_server_name: Option<String>,
server_name_resolver: Option<Arc<dyn ResolveServerName + Sync + Send>>,
}

impl WantsProtocols1 {
Expand All @@ -186,7 +189,9 @@ impl WantsProtocols1 {
force_https: self.https_only,
http: conn,
tls_config: std::sync::Arc::new(self.tls_config),
override_server_name: self.override_server_name,
server_name_resolver: self
.server_name_resolver
.unwrap_or_else(|| Arc::new(DefaultServerNameResolver::default())),
}
}

Expand Down Expand Up @@ -237,6 +242,22 @@ impl ConnectorBuilder<WantsProtocols1> {
})
}

/// Override server name for the TLS stack
///
/// By default, for each connection hyper-rustls will extract host portion
/// of the destination URL and verify that server certificate contains
/// this value.
///
/// If this method is called, hyper-rustls will instead use this resolver
/// to compute the value used to verify the server certificate.
cpu marked this conversation as resolved.
Show resolved Hide resolved
pub fn with_server_name_resolver(
mut self,
resolver: impl ResolveServerName + 'static + Sync + Send,
) -> Self {
self.0.server_name_resolver = Some(Arc::new(resolver));
self
}

/// Override server name for the TLS stack
///
/// By default, for each connection hyper-rustls will extract host portion
Expand All @@ -246,9 +267,22 @@ impl ConnectorBuilder<WantsProtocols1> {
/// If this method is called, hyper-rustls will instead verify that server
/// certificate contains `override_server_name`. Domain name included in
/// the URL will not affect certificate validation.
pub fn with_server_name(mut self, override_server_name: String) -> Self {
self.0.override_server_name = Some(override_server_name);
self
#[deprecated(
since = "0.27.1",
note = "use Self::with_server_name_resolver with FixedServerNameResolver instead"
)]
pub fn with_server_name(self, mut override_server_name: String) -> Self {
// remove square brackets around IPv6 address.
if let Some(trimmed) = override_server_name
.strip_prefix('[')
.and_then(|s| s.strip_suffix(']'))
{
override_server_name = trimmed.to_string();
}

self.with_server_name_resolver(move |_: &_| {
ServerName::try_from(override_server_name.clone())
})
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ mod log {

pub use crate::config::ConfigBuilderExt;
pub use crate::connector::builder::ConnectorBuilder as HttpsConnectorBuilder;
pub use crate::connector::HttpsConnector;
pub use crate::connector::{
DefaultServerNameResolver, FixedServerNameResolver, HttpsConnector, ResolveServerName,
};
pub use crate::stream::MaybeHttpsStream;

/// The various states of the [`HttpsConnectorBuilder`]
Expand Down
Loading