Skip to content

Commit

Permalink
Set HTTP hostname based on TLS server name
Browse files Browse the repository at this point in the history
The net/http library uses the Host field from the Request object in
order to determine the value of the Host header [1]. In order for the
Prometheus client to support SNI, it needs to set this field to the
value provided in the TLS server name.

[1] golang/go#29865

Signed-off-by: fpetkovski <[email protected]>
  • Loading branch information
fpetkovski committed Dec 21, 2021
1 parent f57586d commit c81c25d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
25 changes: 25 additions & 0 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
http2t.ReadIdleTimeout = time.Minute
}

// hostnameRoundTripper sets the http.Request Host to the value set as the TLS server name
rt = newHostnameRoundTripper(tlsConfig, rt)

// If a authorization_credentials is provided, create a round tripper that will set the
// Authorization header correctly on each request.
if cfg.Authorization != nil && len(cfg.Authorization.Credentials) > 0 {
Expand Down Expand Up @@ -457,6 +460,28 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT)
}

type hostnameRoundTripper struct {
tlsConfig *tls.Config
rt http.RoundTripper
}

func newHostnameRoundTripper(tlsConfig *tls.Config, rt http.RoundTripper) http.RoundTripper {
return &hostnameRoundTripper{
tlsConfig: tlsConfig,
rt: rt,
}
}

func (rt *hostnameRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
if rt.tlsConfig.ServerName != "" {
hostParts := strings.Split(req.Host, ":")
hostParts[0] = rt.tlsConfig.ServerName
req.Host = strings.Join(hostParts, ":")
}
return rt.rt.RoundTrip(req)
}

type authorizationCredentialsRoundTripper struct {
authType string
authCredentials Secret
Expand Down
25 changes: 23 additions & 2 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (
"testing"
"time"

yaml "gopkg.in/yaml.v2"
"gopkg.in/yaml.v2"
)

const (
Expand Down Expand Up @@ -179,6 +179,27 @@ func TestNewClientFromConfig(t *testing.T) {
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, ExpectedMessage)
},
}, {
clientConfig: HTTPClientConfig{
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "test-domain.com",
InsecureSkipVerify: true},
},
handler: func(w http.ResponseWriter, r *http.Request) {
srvAddr := r.Context().Value(http.LocalAddrContextKey).(net.Addr)
srvPort := strings.Split(srvAddr.String(), ":")[1]

expectedHostHeader := "test-domain.com:" + srvPort
actualHostHeader := r.Host
if actualHostHeader != expectedHostHeader {
fmt.Fprintf(w, "The expected Host header (%s) differs from the obtained Host header (%s)",
expectedHostHeader, actualHostHeader)
}
fmt.Fprint(w, ExpectedMessage)
},
}, {
clientConfig: HTTPClientConfig{
BearerToken: BearerToken,
Expand Down Expand Up @@ -512,7 +533,7 @@ func TestCustomIdleConnTimeout(t *testing.T) {
t.Fatalf("Can't create a round-tripper from this config: %+v", cfg)
}

transport, ok := rt.(*http.Transport)
transport, ok := rt.(*hostnameRoundTripper).rt.(*http.Transport)
if !ok {
t.Fatalf("Unexpected transport: %+v", transport)
}
Expand Down

0 comments on commit c81c25d

Please sign in to comment.