Skip to content

Commit

Permalink
Set HTTP hostname based on TLS server name
Browse files Browse the repository at this point in the history
The net/http library uses the Host field from the Request object in
order to determine the value of the Host header [1]. In order for the
Prometheus client to support SNI, it needs to set this field to the
value provided in the TLS server name.

[1] golang/go#29865

Signed-off-by: fpetkovski <[email protected]>
  • Loading branch information
fpetkovski committed Dec 21, 2021
1 parent f57586d commit 038724f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
18 changes: 18 additions & 0 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
"sync"
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"

"github.com/mwitkow/go-conntrack"
"golang.org/x/net/http2"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -418,6 +420,8 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
http2t.ReadIdleTimeout = time.Minute
}

rt = promhttp.RoundTripperFunc(HostHeaderRoundTripper(tlsConfig, rt))

// If a authorization_credentials is provided, create a round tripper that will set the
// Authorization header correctly on each request.
if cfg.Authorization != nil && len(cfg.Authorization.Credentials) > 0 {
Expand Down Expand Up @@ -457,6 +461,20 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT)
}

func HostHeaderRoundTripper(config *tls.Config, rt http.RoundTripper) func(r *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
if config != nil && config.ServerName != "" {
hostParts := strings.Split(req.Host, ":")
hostParts[0] = config.ServerName
hostname := strings.Join(hostParts, ":")

req.Host = hostname
}
return rt.RoundTrip(req)
}
}

type authorizationCredentialsRoundTripper struct {
authType string
authCredentials Secret
Expand Down
23 changes: 22 additions & 1 deletion config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (
"testing"
"time"

yaml "gopkg.in/yaml.v2"
"gopkg.in/yaml.v2"
)

const (
Expand Down Expand Up @@ -179,6 +179,27 @@ func TestNewClientFromConfig(t *testing.T) {
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, ExpectedMessage)
},
}, {
clientConfig: HTTPClientConfig{
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "test-domain.com",
InsecureSkipVerify: true},
},
handler: func(w http.ResponseWriter, r *http.Request) {
srvAddr := r.Context().Value(http.LocalAddrContextKey).(net.Addr)
srvPort := strings.Split(srvAddr.String(), ":")[1]

expectedHostHeader := "test-domain.com:" + srvPort
actualHostHeader := r.Host
if actualHostHeader != expectedHostHeader {
fmt.Fprintf(w, "The expected Host header (%s) differs from the obtained Host header (%s)",
expectedHostHeader, actualHostHeader)
}
fmt.Fprint(w, ExpectedMessage)
},
}, {
clientConfig: HTTPClientConfig{
BearerToken: BearerToken,
Expand Down

0 comments on commit 038724f

Please sign in to comment.