diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index b36f71c3..5ceb5909 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -24,8 +24,8 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@9bb56186c3b09b4f86b1c65136769dd318469633 # v4.1.2 - - name: install Go + uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6 + - name: Install Go uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0 with: go-version: 1.22.x @@ -33,6 +33,7 @@ jobs: run: sudo apt-get update && sudo apt-get -y install libsnmp-dev if: github.repository == 'prometheus/snmp_exporter' - name: Lint - uses: golangci/golangci-lint-action@9d1e0624a798bb64f6c3cea93db47765312263dc # v5.1.0 + uses: golangci/golangci-lint-action@a4f60bb28d35aeee14e6880718e0c85ff1882e64 # v6.0.1 with: - version: v1.56.2 + args: --verbose + version: v1.59.0 diff --git a/Makefile.common b/Makefile.common index 0e9ace29..16172923 100644 --- a/Makefile.common +++ b/Makefile.common @@ -61,7 +61,7 @@ PROMU_URL := https://github.com/prometheus/promu/releases/download/v$(PROMU_ SKIP_GOLANGCI_LINT := GOLANGCI_LINT := GOLANGCI_LINT_OPTS ?= -GOLANGCI_LINT_VERSION ?= v1.56.2 +GOLANGCI_LINT_VERSION ?= v1.59.0 # golangci-lint only supports linux, darwin and windows platforms on i386/amd64/arm64. # windows isn't included here because of the path separator being different. ifeq ($(GOHOSTOS),$(filter $(GOHOSTOS),linux darwin)) diff --git a/config/http_config.go b/config/http_config.go index 3241b1ad..75e38bb8 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -130,8 +131,12 @@ func (tv *TLSVersion) String() string { type BasicAuth struct { Username string `yaml:"username" json:"username"` UsernameFile string `yaml:"username_file,omitempty" json:"username_file,omitempty"` + // UsernameRef is the name of the secret within the secret manager to use as the username. + UsernameRef string `yaml:"username_ref,omitempty" json:"username_ref,omitempty"` Password Secret `yaml:"password,omitempty" json:"password,omitempty"` PasswordFile string `yaml:"password_file,omitempty" json:"password_file,omitempty"` + // PasswordRef is the name of the secret within the secret manager to use as the password. + PasswordRef string `yaml:"password_ref,omitempty" json:"password_ref,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -148,6 +153,8 @@ type Authorization struct { Type string `yaml:"type,omitempty" json:"type,omitempty"` Credentials Secret `yaml:"credentials,omitempty" json:"credentials,omitempty"` CredentialsFile string `yaml:"credentials_file,omitempty" json:"credentials_file,omitempty"` + // CredentialsRef is the name of the secret within the secret manager to use as credentials. + CredentialsRef string `yaml:"credentials_ref,omitempty" json:"credentials_ref,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -224,14 +231,17 @@ func (u URL) MarshalJSON() ([]byte, error) { // OAuth2 is the oauth2 client configuration. type OAuth2 struct { - ClientID string `yaml:"client_id" json:"client_id"` - ClientSecret Secret `yaml:"client_secret" json:"client_secret"` - ClientSecretFile string `yaml:"client_secret_file" json:"client_secret_file"` - Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` - TokenURL string `yaml:"token_url" json:"token_url"` - EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"` - TLSConfig TLSConfig `yaml:"tls_config,omitempty"` - ProxyConfig `yaml:",inline"` + ClientID string `yaml:"client_id" json:"client_id"` + ClientSecret Secret `yaml:"client_secret" json:"client_secret"` + ClientSecretFile string `yaml:"client_secret_file" json:"client_secret_file"` + // ClientSecretRef is the name of the secret within the secret manager to use as the client + // secret. + ClientSecretRef string `yaml:"client_secret_ref" json:"client_secret_ref"` + Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` + TokenURL string `yaml:"token_url" json:"token_url"` + EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"` + TLSConfig TLSConfig `yaml:"tls_config,omitempty"` + ProxyConfig `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface @@ -253,12 +263,12 @@ func (o *OAuth2) UnmarshalJSON(data []byte) error { } // SetDirectory joins any relative file paths with dir. -func (a *OAuth2) SetDirectory(dir string) { - if a == nil { +func (o *OAuth2) SetDirectory(dir string) { + if o == nil { return } - a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile) - a.TLSConfig.SetDirectory(dir) + o.ClientSecretFile = JoinDir(dir, o.ClientSecretFile) + o.TLSConfig.SetDirectory(dir) } // LoadHTTPConfig parses the YAML input s into a HTTPClientConfig. @@ -329,6 +339,18 @@ func (c *HTTPClientConfig) SetDirectory(dir string) { c.BearerTokenFile = JoinDir(dir, c.BearerTokenFile) } +// nonZeroCount returns the amount of values that are non-zero. +func nonZeroCount[T comparable](values ...T) int { + count := 0 + var zero T + for _, value := range values { + if value != zero { + count += 1 + } + } + return count +} + // Validate validates the HTTPClientConfig to check only one of BearerToken, // BasicAuth and BearerTokenFile is configured. It also validates that ProxyURL // is set if ProxyConnectHeader is set. @@ -340,17 +362,17 @@ func (c *HTTPClientConfig) Validate() error { if (c.BasicAuth != nil || c.OAuth2 != nil) && (len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0) { return fmt.Errorf("at most one of basic_auth, oauth2, bearer_token & bearer_token_file must be configured") } - if c.BasicAuth != nil && (string(c.BasicAuth.Username) != "" && c.BasicAuth.UsernameFile != "") { - return fmt.Errorf("at most one of basic_auth username & username_file must be configured") + if c.BasicAuth != nil && nonZeroCount(string(c.BasicAuth.Username) != "", c.BasicAuth.UsernameFile != "", c.BasicAuth.UsernameRef != "") > 1 { + return fmt.Errorf("at most one of basic_auth username, username_file & username_ref must be configured") } - if c.BasicAuth != nil && (string(c.BasicAuth.Password) != "" && c.BasicAuth.PasswordFile != "") { - return fmt.Errorf("at most one of basic_auth password & password_file must be configured") + if c.BasicAuth != nil && nonZeroCount(string(c.BasicAuth.Password) != "", c.BasicAuth.PasswordFile != "", c.BasicAuth.PasswordRef != "") > 1 { + return fmt.Errorf("at most one of basic_auth password, password_file & password_ref must be configured") } if c.Authorization != nil { if len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0 { return fmt.Errorf("authorization is not compatible with bearer_token & bearer_token_file") } - if string(c.Authorization.Credentials) != "" && c.Authorization.CredentialsFile != "" { + if nonZeroCount(string(c.Authorization.Credentials) != "", c.Authorization.CredentialsFile != "", c.Authorization.CredentialsRef != "") > 1 { return fmt.Errorf("at most one of authorization credentials & credentials_file must be configured") } c.Authorization.Type = strings.TrimSpace(c.Authorization.Type) @@ -385,8 +407,8 @@ func (c *HTTPClientConfig) Validate() error { if len(c.OAuth2.TokenURL) == 0 { return fmt.Errorf("oauth2 token_url must be configured") } - if len(c.OAuth2.ClientSecret) > 0 && len(c.OAuth2.ClientSecretFile) > 0 { - return fmt.Errorf("at most one of oauth2 client_secret & client_secret_file must be configured") + if nonZeroCount(len(c.OAuth2.ClientSecret) > 0, len(c.OAuth2.ClientSecretFile) > 0, len(c.OAuth2.ClientSecretRef) > 0) > 1 { + return fmt.Errorf("at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured") } } if err := c.ProxyConfig.Validate(); err != nil { @@ -437,50 +459,78 @@ type httpClientOptions struct { idleConnTimeout time.Duration userAgent string host string + secretManager SecretManager } // HTTPClientOption defines an option that can be applied to the HTTP client. -type HTTPClientOption func(options *httpClientOptions) +type HTTPClientOption interface { + applyToHTTPClientOptions(options *httpClientOptions) +} + +type httpClientOptionFunc func(options *httpClientOptions) + +func (f httpClientOptionFunc) applyToHTTPClientOptions(options *httpClientOptions) { + f(options) +} // WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`. func WithDialContextFunc(fn DialContextFunc) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.dialContextFunc = fn - } + }) } // WithKeepAlivesDisabled allows to disable HTTP keepalive. func WithKeepAlivesDisabled() HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.keepAlivesEnabled = false - } + }) } // WithHTTP2Disabled allows to disable HTTP2. func WithHTTP2Disabled() HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.http2Enabled = false - } + }) } // WithIdleConnTimeout allows setting the idle connection timeout. func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.idleConnTimeout = timeout - } + }) } // WithUserAgent allows setting the user agent. func WithUserAgent(ua string) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.userAgent = ua - } + }) } // WithHost allows setting the host header. func WithHost(host string) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.host = host + }) +} + +type secretManagerOption struct { + secretManager SecretManager +} + +func (s *secretManagerOption) applyToHTTPClientOptions(opts *httpClientOptions) { + opts.secretManager = s.secretManager +} + +func (s *secretManagerOption) applyToTLSConfigOptions(opts *tlsConfigOptions) { + opts.secretManager = s.secretManager +} + +// WithSecretManager allows setting the secret manager. +func WithSecretManager(manager SecretManager) *secretManagerOption { + return &secretManagerOption{ + secretManager: manager, } } @@ -510,9 +560,16 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClie // given config.HTTPClientConfig and config.HTTPClientOption. // The name is used as go-conntrack metric label. func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { + return NewRoundTripperFromConfigWithContext(context.Background(), cfg, name, optFuncs...) +} + +// NewRoundTripperFromConfigWithContext returns a new HTTP RoundTripper configured for the +// given config.HTTPClientConfig and config.HTTPClientOption. +// The name is used as go-conntrack metric label. +func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { opts := defaultHTTPClientOptions - for _, f := range optFuncs { - f(&opts) + for _, opt := range optFuncs { + opt.applyToHTTPClientOptions(&opts) } var dialContext func(ctx context.Context, network, addr string) (net.Conn, error) @@ -560,25 +617,41 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT // If a authorization_credentials is provided, create a round tripper that will set the // Authorization header correctly on each request. - if cfg.Authorization != nil && len(cfg.Authorization.CredentialsFile) > 0 { - rt = NewAuthorizationCredentialsFileRoundTripper(cfg.Authorization.Type, cfg.Authorization.CredentialsFile, rt) - } else if cfg.Authorization != nil { - rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, cfg.Authorization.Credentials, rt) + if cfg.Authorization != nil { + credentialsSecret, err := toSecret(opts.secretManager, cfg.Authorization.Credentials, cfg.Authorization.CredentialsFile, cfg.Authorization.CredentialsRef) + if err != nil { + return nil, fmt.Errorf("unable to use credentials: %w", err) + } + rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, credentialsSecret, rt) } // Backwards compatibility, be nice with importers who would not have // called Validate(). - if len(cfg.BearerToken) > 0 { - rt = NewAuthorizationCredentialsRoundTripper("Bearer", cfg.BearerToken, rt) - } else if len(cfg.BearerTokenFile) > 0 { - rt = NewAuthorizationCredentialsFileRoundTripper("Bearer", cfg.BearerTokenFile, rt) + if len(cfg.BearerToken) > 0 || len(cfg.BearerTokenFile) > 0 { + bearerSecret, err := toSecret(opts.secretManager, cfg.BearerToken, cfg.BearerTokenFile, "") + if err != nil { + return nil, fmt.Errorf("unable to use bearer token: %w", err) + } + rt = NewAuthorizationCredentialsRoundTripper("Bearer", bearerSecret, rt) } if cfg.BasicAuth != nil { - rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.UsernameFile, cfg.BasicAuth.PasswordFile, rt) + usernameSecret, err := toSecret(opts.secretManager, Secret(cfg.BasicAuth.Username), cfg.BasicAuth.UsernameFile, cfg.BasicAuth.UsernameRef) + if err != nil { + return nil, fmt.Errorf("unable to use username: %w", err) + } + passwordSecret, err := toSecret(opts.secretManager, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, cfg.BasicAuth.PasswordRef) + if err != nil { + return nil, fmt.Errorf("unable to use password: %w", err) + } + rt = NewBasicAuthRoundTripper(usernameSecret, passwordSecret, rt) } if cfg.OAuth2 != nil { - rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts) + clientSecret, err := toSecret(opts.secretManager, cfg.OAuth2.ClientSecret, cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef) + if err != nil { + return nil, fmt.Errorf("unable to use client secret: %w", err) + } + rt = NewOAuth2RoundTripper(clientSecret, cfg.OAuth2, rt, &opts) } if cfg.HTTPHeaders != nil { @@ -597,115 +670,183 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT return rt, nil } - tlsConfig, err := NewTLSConfig(&cfg.TLSConfig) + tlsConfig, err := NewTLSConfig(&cfg.TLSConfig, WithSecretManager(opts.secretManager)) if err != nil { return nil, err } - if len(cfg.TLSConfig.CAFile) == 0 { + tlsSettings, err := cfg.TLSConfig.roundTripperSettings(opts.secretManager) + if err != nil { + return nil, err + } + if tlsSettings.CA == nil || tlsSettings.CA.immutable() { // No need for a RoundTripper that reloads the CA file automatically. return newRT(tlsConfig) } - return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT) + return NewTLSRoundTripperWithContext(ctx, tlsConfig, tlsSettings, newRT) } -type authorizationCredentialsRoundTripper struct { - authType string - authCredentials Secret - rt http.RoundTripper +// SecretManager manages secret data mapped to names known as "references" or "refs". +type SecretManager interface { + // Fetch returns the secret data given a secret name indicated by `secretRef`. + Fetch(ctx context.Context, secretRef string) (string, error) } -// NewAuthorizationCredentialsRoundTripper adds the provided credentials to a -// request unless the authorization header has already been set. -func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials Secret, rt http.RoundTripper) http.RoundTripper { - return &authorizationCredentialsRoundTripper{authType, authCredentials, rt} +type secret interface { + fetch(ctx context.Context) (string, error) + description() string + immutable() bool } -func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(req.Header.Get("Authorization")) == 0 { - req = cloneRequest(req) - req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, string(rt.authCredentials))) +type inlineSecret struct { + text string +} + +func (s *inlineSecret) fetch(context.Context) (string, error) { + return s.text, nil +} + +func (s *inlineSecret) description() string { + return "inline" +} + +func (s *inlineSecret) immutable() bool { + return true +} + +type fileSecret struct { + file string +} + +func (s *fileSecret) fetch(ctx context.Context) (string, error) { + fileBytes, err := os.ReadFile(s.file) + if err != nil { + return "", fmt.Errorf("unable to read file %s: %w", s.file, err) } - return rt.rt.RoundTrip(req) + return strings.TrimSpace(string(fileBytes)), nil } -func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() { - if ci, ok := rt.rt.(closeIdler); ok { - ci.CloseIdleConnections() +func (s *fileSecret) description() string { + return fmt.Sprintf("file %s", s.file) +} + +func (s *fileSecret) immutable() bool { + return false +} + +// refSecret fetches a single secret from a secret manager. +type refSecret struct { + ref string + manager SecretManager // manager is expected to be not nil. +} + +func (s *refSecret) fetch(ctx context.Context) (string, error) { + return s.manager.Fetch(ctx, s.ref) +} + +func (s *refSecret) description() string { + return fmt.Sprintf("ref %s", s.ref) +} + +func (s *refSecret) immutable() bool { + return false +} + +// toSecret returns a secret from one of the given sources, assuming exactly +// one or none of the sources are provided. +func toSecret(secretManager SecretManager, text Secret, file, ref string) (secret, error) { + if text != "" { + return &inlineSecret{ + text: string(text), + }, nil + } + if file != "" { + return &fileSecret{ + file: file, + }, nil } + if ref != "" { + if secretManager == nil { + return nil, errors.New("cannot use secret ref without manager") + } + return &refSecret{ + ref: ref, + manager: secretManager, + }, nil + } + return nil, nil } -type authorizationCredentialsFileRoundTripper struct { - authType string - authCredentialsFile string - rt http.RoundTripper +type authorizationCredentialsRoundTripper struct { + authType string + authCredentials secret + rt http.RoundTripper } -// NewAuthorizationCredentialsFileRoundTripper adds the authorization -// credentials read from the provided file to a request unless the authorization -// header has already been set. This file is read for every request. -func NewAuthorizationCredentialsFileRoundTripper(authType, authCredentialsFile string, rt http.RoundTripper) http.RoundTripper { - return &authorizationCredentialsFileRoundTripper{authType, authCredentialsFile, rt} +// NewAuthorizationCredentialsRoundTripper adds the authorization credentials +// read from the provided secret to a request unless the authorization header +// has already been set. +func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials secret, rt http.RoundTripper) http.RoundTripper { + return &authorizationCredentialsRoundTripper{authType, authCredentials, rt} } -func (rt *authorizationCredentialsFileRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(req.Header.Get("Authorization")) == 0 { - b, err := os.ReadFile(rt.authCredentialsFile) +func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if len(req.Header.Get("Authorization")) != 0 { + return rt.rt.RoundTrip(req) + } + + var authCredentials string + if rt.authCredentials != nil { + var err error + authCredentials, err = rt.authCredentials.fetch(req.Context()) if err != nil { - return nil, fmt.Errorf("unable to read authorization credentials file %s: %w", rt.authCredentialsFile, err) + return nil, fmt.Errorf("unable to read authorization credentials: %w", err) } - authCredentials := strings.TrimSpace(string(b)) - - req = cloneRequest(req) - req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials)) } + req = cloneRequest(req) + req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials)) + return rt.rt.RoundTrip(req) } -func (rt *authorizationCredentialsFileRoundTripper) CloseIdleConnections() { +func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() { if ci, ok := rt.rt.(closeIdler); ok { ci.CloseIdleConnections() } } type basicAuthRoundTripper struct { - username string - password Secret - usernameFile string - passwordFile string - rt http.RoundTripper + username secret + password secret + rt http.RoundTripper } // NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has // already been set. -func NewBasicAuthRoundTripper(username string, password Secret, usernameFile, passwordFile string, rt http.RoundTripper) http.RoundTripper { - return &basicAuthRoundTripper{username, password, usernameFile, passwordFile, rt} +func NewBasicAuthRoundTripper(username secret, password secret, rt http.RoundTripper) http.RoundTripper { + return &basicAuthRoundTripper{username, password, rt} } func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - var username string - var password string if len(req.Header.Get("Authorization")) != 0 { return rt.rt.RoundTrip(req) } - if rt.usernameFile != "" { - usernameBytes, err := os.ReadFile(rt.usernameFile) + var username string + var password string + if rt.username != nil { + var err error + username, err = rt.username.fetch(req.Context()) if err != nil { - return nil, fmt.Errorf("unable to read basic auth username file %s: %w", rt.usernameFile, err) + return nil, fmt.Errorf("unable to read basic auth username: %w", err) } - username = strings.TrimSpace(string(usernameBytes)) - } else { - username = rt.username } - if rt.passwordFile != "" { - passwordBytes, err := os.ReadFile(rt.passwordFile) + if rt.password != nil { + var err error + password, err = rt.password.fetch(req.Context()) if err != nil { - return nil, fmt.Errorf("unable to read basic auth password file %s: %w", rt.passwordFile, err) + return nil, fmt.Errorf("unable to read basic auth password: %w", err) } - password = strings.TrimSpace(string(passwordBytes)) - } else { - password = string(rt.password) } req = cloneRequest(req) req.SetBasicAuth(username, password) @@ -719,104 +860,118 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() { } type oauth2RoundTripper struct { - config *OAuth2 - rt http.RoundTripper - next http.RoundTripper - secret string - mtx sync.RWMutex - opts *httpClientOptions - client *http.Client + mtx sync.RWMutex + lastRT *oauth2.Transport + lastSecret string + + // Required for interaction with Oauth2 server. + config *OAuth2 + clientSecret secret + opts *httpClientOptions + client *http.Client } -func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { +func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { + if clientSecret == nil { + clientSecret = &inlineSecret{text: ""} + } + return &oauth2RoundTripper{ config: config, - next: next, - opts: opts, + // A correct tokenSource will be added later on. + lastRT: &oauth2.Transport{Base: next}, + opts: opts, + clientSecret: clientSecret, } } -func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - var ( - secret string - changed bool - ) +func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret string) (client *http.Client, source oauth2.TokenSource, err error) { + tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager)) + if err != nil { + return nil, nil, err + } - if rt.config.ClientSecretFile != "" { - data, err := os.ReadFile(rt.config.ClientSecretFile) + tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) { + return &http.Transport{ + TLSClientConfig: tlsConfig, + Proxy: rt.config.ProxyConfig.Proxy(), + ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(), + DisableKeepAlives: !rt.opts.keepAlivesEnabled, + MaxIdleConns: 20, + MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801 + IdleConnTimeout: 10 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, nil + } + + var t http.RoundTripper + tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager) + if err != nil { + return nil, nil, err + } + if tlsSettings.CA == nil || tlsSettings.CA.immutable() { + t, _ = tlsTransport(tlsConfig) + } else { + t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport) if err != nil { - return nil, fmt.Errorf("unable to read oauth2 client secret file %s: %w", rt.config.ClientSecretFile, err) + return nil, nil, err } - secret = strings.TrimSpace(string(data)) - rt.mtx.RLock() - changed = secret != rt.secret - rt.mtx.RUnlock() - } else { - // Either an inline secret or nothing (use an empty string) was provided. - secret = string(rt.config.ClientSecret) } - if changed || rt.rt == nil { - config := &clientcredentials.Config{ - ClientID: rt.config.ClientID, - ClientSecret: secret, - Scopes: rt.config.Scopes, - TokenURL: rt.config.TokenURL, - EndpointParams: mapToValues(rt.config.EndpointParams), - } + if ua := req.UserAgent(); ua != "" { + t = NewUserAgentRoundTripper(ua, t) + } - tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig) - if err != nil { - return nil, err - } + config := &clientcredentials.Config{ + ClientID: rt.config.ClientID, + ClientSecret: secret, + Scopes: rt.config.Scopes, + TokenURL: rt.config.TokenURL, + EndpointParams: mapToValues(rt.config.EndpointParams), + } + client = &http.Client{Transport: t} + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) + return client, config.TokenSource(ctx), nil +} - tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) { - return &http.Transport{ - TLSClientConfig: tlsConfig, - Proxy: rt.config.ProxyConfig.Proxy(), - ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(), - DisableKeepAlives: !rt.opts.keepAlivesEnabled, - MaxIdleConns: 20, - MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801 - IdleConnTimeout: 10 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, nil - } +func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + var ( + secret string + needsInit bool + ) - var t http.RoundTripper - if len(rt.config.TLSConfig.CAFile) == 0 { - t, _ = tlsTransport(tlsConfig) - } else { - t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.roundTripperSettings(), tlsTransport) + rt.mtx.RLock() + secret = rt.lastSecret + needsInit = rt.lastRT.Source == nil + rt.mtx.RUnlock() + + // Fetch the secret if it's our first run or always if the secret can change. + if !rt.clientSecret.immutable() || needsInit { + newSecret, err := rt.clientSecret.fetch(req.Context()) + if err != nil { + return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) + } + if newSecret != secret || needsInit { + // Secret changed or it's a first run. Rebuilt oauth2 setup. + client, source, err := rt.newOauth2TokenSource(req, newSecret) if err != nil { return nil, err } - } - - if ua := req.UserAgent(); ua != "" { - t = NewUserAgentRoundTripper(ua, t) - } - - client := &http.Client{Transport: t} - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) - tokenSource := config.TokenSource(ctx) - rt.mtx.Lock() - rt.secret = secret - rt.rt = &oauth2.Transport{ - Base: rt.next, - Source: tokenSource, - } - if rt.client != nil { - rt.client.CloseIdleConnections() + rt.mtx.Lock() + rt.lastSecret = secret + rt.lastRT.Source = source + if rt.client != nil { + rt.client.CloseIdleConnections() + } + rt.client = client + rt.mtx.Unlock() } - rt.client = client - rt.mtx.Unlock() } rt.mtx.RLock() - currentRT := rt.rt + currentRT := rt.lastRT rt.mtx.RUnlock() return currentRT.RoundTrip(req) } @@ -825,7 +980,7 @@ func (rt *oauth2RoundTripper) CloseIdleConnections() { if rt.client != nil { rt.client.CloseIdleConnections() } - if ci, ok := rt.next.(closeIdler); ok { + if ci, ok := rt.lastRT.Base.(closeIdler); ok { ci.CloseIdleConnections() } } @@ -853,8 +1008,27 @@ func cloneRequest(r *http.Request) *http.Request { return r2 } +type tlsConfigOptions struct { + secretManager SecretManager +} + +// TLSConfigOption defines an option that can be applied to the HTTP client. +type TLSConfigOption interface { + applyToTLSConfigOptions(options *tlsConfigOptions) +} + // NewTLSConfig creates a new tls.Config from the given TLSConfig. -func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { +func NewTLSConfig(cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) { + return NewTLSConfigWithContext(context.Background(), cfg, optFuncs...) +} + +// NewTLSConfigWithContext creates a new tls.Config from the given TLSConfig. +func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) { + opts := tlsConfigOptions{} + for _, opt := range optFuncs { + opt.applyToTLSConfigOptions(&opts) + } + if err := cfg.Validate(); err != nil { return nil, err } @@ -873,17 +1047,17 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // If a CA cert is provided then let's read it in so we can validate the // scrape target's certificate properly. - if len(cfg.CA) > 0 { - if !updateRootCA(tlsConfig, []byte(cfg.CA)) { - return nil, fmt.Errorf("unable to use inline CA cert") - } - } else if len(cfg.CAFile) > 0 { - b, err := readCAFile(cfg.CAFile) + caSecret, err := toSecret(opts.secretManager, Secret(cfg.CA), cfg.CAFile, cfg.CARef) + if err != nil { + return nil, fmt.Errorf("unable to use CA cert: %w", err) + } + if caSecret != nil { + ca, err := caSecret.fetch(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to read CA cert: %w", err) } - if !updateRootCA(tlsConfig, b) { - return nil, fmt.Errorf("unable to use specified CA cert %s", cfg.CAFile) + if !updateRootCA(tlsConfig, []byte(ca)) { + return nil, fmt.Errorf("unable to use specified CA cert %s", caSecret.description()) } } @@ -894,10 +1068,16 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // If a client cert & key is provided then configure TLS config accordingly. if cfg.usingClientCert() && cfg.usingClientKey() { // Verify that client cert and key are valid. - if _, err := cfg.getClientCertificate(nil); err != nil { + if _, err := cfg.getClientCertificate(ctx, opts.secretManager); err != nil { return nil, err } - tlsConfig.GetClientCertificate = cfg.getClientCertificate + tlsConfig.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + var ctx context.Context + if cri != nil { + ctx = cri.Context() + } + return cfg.getClientCertificate(ctx, opts.secretManager) + } } return tlsConfig, nil @@ -917,6 +1097,15 @@ type TLSConfig struct { CertFile string `yaml:"cert_file,omitempty" json:"cert_file,omitempty"` // The client key file for the targets. KeyFile string `yaml:"key_file,omitempty" json:"key_file,omitempty"` + // CARef is the name of the secret within the secret manager to use as the CA cert for the + // targets. + CARef string `yaml:"ca_ref,omitempty" json:"ca_ref,omitempty"` + // CertRef is the name of the secret within the secret manager to use as the client cert for + // the targets. + CertRef string `yaml:"cert_ref,omitempty" json:"cert_ref,omitempty"` + // KeyRef is the name of the secret within the secret manager to use as the client key for + // the targets. + KeyRef string `yaml:"key_ref,omitempty" json:"key_ref,omitempty"` // Used to verify the hostname for the targets. ServerName string `yaml:"server_name,omitempty" json:"server_name,omitempty"` // Disable target certificate validation. @@ -950,13 +1139,13 @@ func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { // file-based fields for the TLS CA, client certificate, and client key are // used. func (c *TLSConfig) Validate() error { - if len(c.CA) > 0 && len(c.CAFile) > 0 { - return fmt.Errorf("at most one of ca and ca_file must be configured") + if nonZeroCount(len(c.CA) > 0, len(c.CAFile) > 0, len(c.CARef) > 0) > 1 { + return fmt.Errorf("at most one of ca, ca_file & ca_ref must be configured") } - if len(c.Cert) > 0 && len(c.CertFile) > 0 { - return fmt.Errorf("at most one of cert and cert_file must be configured") + if nonZeroCount(len(c.Cert) > 0, len(c.CertFile) > 0, len(c.CertRef) > 0) > 1 { + return fmt.Errorf("at most one of cert, cert_file & cert_ref must be configured") } - if len(c.Key) > 0 && len(c.KeyFile) > 0 { + if nonZeroCount(len(c.Key) > 0, len(c.KeyFile) > 0, len(c.KeyRef) > 0) > 1 { return fmt.Errorf("at most one of key and key_file must be configured") } @@ -970,66 +1159,70 @@ func (c *TLSConfig) Validate() error { } func (c *TLSConfig) usingClientCert() bool { - return len(c.Cert) > 0 || len(c.CertFile) > 0 + return len(c.Cert) > 0 || len(c.CertFile) > 0 || len(c.CertRef) > 0 } func (c *TLSConfig) usingClientKey() bool { - return len(c.Key) > 0 || len(c.KeyFile) > 0 + return len(c.Key) > 0 || len(c.KeyFile) > 0 || len(c.KeyRef) > 0 } -func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings { - return TLSRoundTripperSettings{ - CA: c.CA, - CAFile: c.CAFile, - Cert: c.Cert, - CertFile: c.CertFile, - Key: string(c.Key), - KeyFile: c.KeyFile, +func (c *TLSConfig) roundTripperSettings(secretManager SecretManager) (TLSRoundTripperSettings, error) { + ca, err := toSecret(secretManager, Secret(c.CA), c.CAFile, c.CARef) + if err != nil { + return TLSRoundTripperSettings{}, err + } + cert, err := toSecret(secretManager, Secret(c.Cert), c.CertFile, c.CertRef) + if err != nil { + return TLSRoundTripperSettings{}, err } + key, err := toSecret(secretManager, c.Key, c.KeyFile, c.KeyRef) + if err != nil { + return TLSRoundTripperSettings{}, err + } + return TLSRoundTripperSettings{ + CA: ca, + Cert: cert, + Key: key, + }, nil } -// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate. -func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { +// getClientCertificate reads the pair of client cert and key and returns a tls.Certificate. +func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager SecretManager) (*tls.Certificate, error) { var ( - certData, keyData []byte + certData, keyData string err error ) - if c.CertFile != "" { - certData, err = os.ReadFile(c.CertFile) + certSecret, err := toSecret(secretManager, Secret(c.Cert), c.CertFile, c.CertRef) + if err != nil { + return nil, fmt.Errorf("unable to use client cert: %w", err) + } + if certSecret != nil { + certData, err = certSecret.fetch(ctx) if err != nil { - return nil, fmt.Errorf("unable to read specified client cert (%s): %w", c.CertFile, err) + return nil, fmt.Errorf("unable to read specified client cert: %w", err) } - } else { - certData = []byte(c.Cert) } - if c.KeyFile != "" { - keyData, err = os.ReadFile(c.KeyFile) + keySecret, err := toSecret(secretManager, Secret(c.Key), c.KeyFile, c.KeyRef) + if err != nil { + return nil, fmt.Errorf("unable to use client key: %w", err) + } + if keySecret != nil { + keyData, err = keySecret.fetch(ctx) if err != nil { - return nil, fmt.Errorf("unable to read specified client key (%s): %w", c.KeyFile, err) + return nil, fmt.Errorf("unable to read specified client key: %w", err) } - } else { - keyData = []byte(c.Key) } - cert, err := tls.X509KeyPair(certData, keyData) + cert, err := tls.X509KeyPair([]byte(certData), []byte(keyData)) if err != nil { - return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", c.CertFile, c.KeyFile, err) + return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.description(), keySecret.description(), err) } return &cert, nil } -// readCAFile reads the CA cert file from disk. -func readCAFile(f string) ([]byte, error) { - data, err := os.ReadFile(f) - if err != nil { - return nil, fmt.Errorf("unable to load specified CA cert %s: %w", f, err) - } - return data, nil -} - // updateRootCA parses the given byte slice as a series of PEM encoded certificates and updates tls.Config.RootCAs. func updateRootCA(cfg *tls.Config, b []byte) bool { caCertPool := x509.NewCertPool() @@ -1057,15 +1250,24 @@ type tlsRoundTripper struct { } type TLSRoundTripperSettings struct { - CA, CAFile string - Cert, CertFile string - Key, KeyFile string + CA secret + Cert secret + Key secret } func NewTLSRoundTripper( cfg *tls.Config, settings TLSRoundTripperSettings, newRT func(*tls.Config) (http.RoundTripper, error), +) (http.RoundTripper, error) { + return NewTLSRoundTripperWithContext(context.Background(), cfg, settings, newRT) +} + +func NewTLSRoundTripperWithContext( + ctx context.Context, + cfg *tls.Config, + settings TLSRoundTripperSettings, + newRT func(*tls.Config) (http.RoundTripper, error), ) (http.RoundTripper, error) { t := &tlsRoundTripper{ settings: settings, @@ -1078,7 +1280,7 @@ func NewTLSRoundTripper( return nil, err } t.rt = rt - _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash() + _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash(ctx) if err != nil { return nil, err } @@ -1086,38 +1288,31 @@ func NewTLSRoundTripper( return t, nil } -func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) { - var ( - caBytes, certBytes, keyBytes []byte - - err error - ) +func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byte, []byte, []byte, error) { + var caBytes, certBytes, keyBytes []byte - if t.settings.CAFile != "" { - caBytes, err = os.ReadFile(t.settings.CAFile) + if t.settings.CA != nil { + ca, err := t.settings.CA.fetch(ctx) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, fmt.Errorf("unable to read CA cert: %w", err) } - } else if t.settings.CA != "" { - caBytes = []byte(t.settings.CA) + caBytes = []byte(ca) } - if t.settings.CertFile != "" { - certBytes, err = os.ReadFile(t.settings.CertFile) + if t.settings.Cert != nil { + cert, err := t.settings.Cert.fetch(ctx) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, fmt.Errorf("unable to read client cert: %w", err) } - } else if t.settings.Cert != "" { - certBytes = []byte(t.settings.Cert) + certBytes = []byte(cert) } - if t.settings.KeyFile != "" { - keyBytes, err = os.ReadFile(t.settings.KeyFile) + if t.settings.Key != nil { + key, err := t.settings.Key.fetch(ctx) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, fmt.Errorf("unable to read client key: %w", err) } - } else if t.settings.Key != "" { - keyBytes = []byte(t.settings.Key) + keyBytes = []byte(key) } var caHash, certHash, keyHash [32]byte @@ -1137,7 +1332,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, // RoundTrip implements the http.RoundTrip interface. func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash() + caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash(req.Context()) if err != nil { return nil, err } @@ -1158,7 +1353,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // using GetClientCertificate. tlsConfig := t.tlsConfig.Clone() if !updateRootCA(tlsConfig, caData) { - return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CAFile) + return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CA.description()) } rt, err = t.newRT(tlsConfig) if err != nil { diff --git a/config/http_config_test.go b/config/http_config_test.go index 01dc7d6b..14e07c22 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -81,11 +81,11 @@ var invalidHTTPClientConfigs = []struct { }, { httpClientConfigFile: "testdata/http.conf.basic-auth.too-much.bad.yaml", - errMsg: "at most one of basic_auth password & password_file must be configured", + errMsg: "at most one of basic_auth password, password_file & password_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.basic-auth.bad-username.yaml", - errMsg: "at most one of basic_auth username & username_file must be configured", + errMsg: "at most one of basic_auth username, username_file & username_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.mix-bearer-and-creds.bad.yaml", @@ -109,7 +109,7 @@ var invalidHTTPClientConfigs = []struct { }, { httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml", - errMsg: "at most one of oauth2 client_secret & client_secret_file must be configured", + errMsg: "at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml", @@ -507,46 +507,48 @@ func TestNewClientFromConfig(t *testing.T) { } for _, validConfig := range newClientValidConfig { - testServer, err := newTestServer(validConfig.handler) - if err != nil { - t.Fatal(err.Error()) - } - defer testServer.Close() + t.Run("", func(t *testing.T) { + testServer, err := newTestServer(validConfig.handler) + if err != nil { + t.Fatal(err.Error()) + } + defer testServer.Close() - if validConfig.clientConfig.OAuth2 != nil { - // We don't have access to the test server's URL when configuring the test cases, - // so it has to be specified here. - validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token" - } + if validConfig.clientConfig.OAuth2 != nil { + // We don't have access to the test server's URL when configuring the test cases, + // so it has to be specified here. + validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token" + } - err = validConfig.clientConfig.Validate() - if err != nil { - t.Fatal(err.Error()) - } - client, err := NewClientFromConfig(validConfig.clientConfig, "test") - if err != nil { - t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig) - continue - } + err = validConfig.clientConfig.Validate() + if err != nil { + t.Fatal(err.Error()) + } + client, err := NewClientFromConfig(validConfig.clientConfig, "test") + if err != nil { + t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig) + return + } - response, err := client.Get(testServer.URL) - if err != nil { - t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err) - continue - } + response, err := client.Get(testServer.URL) + if err != nil { + t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err) + return + } - message, err := io.ReadAll(response.Body) - response.Body.Close() - if err != nil { - t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig) - continue - } + message, err := io.ReadAll(response.Body) + response.Body.Close() + if err != nil { + t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig) + return + } - trimMessage := strings.TrimSpace(string(message)) - if ExpectedMessage != trimMessage { - t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v", - ExpectedMessage, trimMessage, validConfig.clientConfig) - } + trimMessage := strings.TrimSpace(string(message)) + if ExpectedMessage != trimMessage { + t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v", + ExpectedMessage, trimMessage, validConfig.clientConfig) + } + }) } } @@ -606,7 +608,7 @@ func TestNewClientFromInvalidConfig(t *testing.T) { InsecureSkipVerify: true, }, }, - errorMsg: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA), + errorMsg: fmt.Sprintf("unable to read CA cert: unable to read file %s", MissingCA), }, { clientConfig: HTTPClientConfig{ @@ -615,7 +617,7 @@ func TestNewClientFromInvalidConfig(t *testing.T) { InsecureSkipVerify: true, }, }, - errorMsg: fmt.Sprintf("unable to use specified CA cert %s", InvalidCA), + errorMsg: fmt.Sprintf("unable to use specified CA cert file %s", InvalidCA), }, } @@ -706,7 +708,7 @@ func TestMissingBearerAuthFile(t *testing.T) { t.Fatal("No error is returned here") } - if !strings.Contains(err.Error(), "unable to read authorization credentials file missing/bearer.token: open missing/bearer.token: no such file or directory") { + if !strings.Contains(err.Error(), "unable to read authorization credentials: unable to read file missing/bearer.token: open missing/bearer.token: no such file or directory") { t.Fatal("wrong error message being returned") } } @@ -725,7 +727,7 @@ func TestBearerAuthRoundTripper(t *testing.T) { }, nil, nil) // Normal flow. - bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", BearerToken, fakeRoundTripper) + bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", &inlineSecret{text: BearerToken}, fakeRoundTripper) request, _ := http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("User-Agent", "Douglas Adams mind") _, err := bearerAuthRoundTripper.RoundTrip(request) @@ -734,7 +736,7 @@ func TestBearerAuthRoundTripper(t *testing.T) { } // Should honor already Authorization header set. - bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", newBearerToken, fakeRoundTripper) + bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", &inlineSecret{text: newBearerToken}, fakeRoundTripper) request, _ = http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("Authorization", ExpectedBearer) _, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request) @@ -753,7 +755,7 @@ func TestBearerAuthFileRoundTripper(t *testing.T) { }, nil, nil) // Normal flow. - bearerAuthRoundTripper := NewAuthorizationCredentialsFileRoundTripper("Bearer", BearerTokenFile, fakeRoundTripper) + bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", &fileSecret{file: BearerTokenFile}, fakeRoundTripper) request, _ := http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("User-Agent", "Douglas Adams mind") _, err := bearerAuthRoundTripper.RoundTrip(request) @@ -762,7 +764,7 @@ func TestBearerAuthFileRoundTripper(t *testing.T) { } // Should honor already Authorization header set. - bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsFileRoundTripper("Bearer", MissingBearerTokenFile, fakeRoundTripper) + bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", &fileSecret{file: MissingBearerTokenFile}, fakeRoundTripper) request, _ = http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("Authorization", ExpectedBearer) _, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request) @@ -861,7 +863,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { ServerName: "", InsecureSkipVerify: false, }, - errorMessage: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA), + errorMessage: fmt.Sprintf("unable to read CA cert: unable to read file %s", MissingCA), }, { configTLSConfig: TLSConfig{ @@ -871,7 +873,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { ServerName: "", InsecureSkipVerify: false, }, - errorMessage: fmt.Sprintf("unable to read specified client cert (%s):", MissingCert), + errorMessage: fmt.Sprintf("unable to read specified client cert: unable to read file %s", MissingCert), }, { configTLSConfig: TLSConfig{ @@ -881,7 +883,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { ServerName: "", InsecureSkipVerify: false, }, - errorMessage: fmt.Sprintf("unable to read specified client key (%s):", MissingKey), + errorMessage: fmt.Sprintf("unable to read specified client key: unable to read file %s", MissingKey), }, { configTLSConfig: TLSConfig{ @@ -892,7 +894,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { ServerName: "", InsecureSkipVerify: false, }, - errorMessage: "at most one of cert and cert_file must be configured", + errorMessage: "at most one of cert, cert_file & cert_ref must be configured", }, { configTLSConfig: TLSConfig{ @@ -934,14 +936,11 @@ func TestBasicAuthNoPassword(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if rt.username != "user" { - t.Errorf("Bad HTTP client username: %s", rt.username) - } - if string(rt.password) != "" { - t.Errorf("Expected empty HTTP client password: %s", rt.password) + if username, _ := rt.username.fetch(context.Background()); username != "user" { + t.Errorf("Bad HTTP client username: %s", username) } - if string(rt.passwordFile) != "" { - t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile) + if rt.password != nil { + t.Errorf("Expected empty HTTP client password") } } @@ -960,14 +959,11 @@ func TestBasicAuthNoUsername(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if rt.username != "" { - t.Errorf("Got unexpected username: %s", rt.username) + if rt.username != nil { + t.Errorf("Got unexpected username") } - if string(rt.password) != "secret" { - t.Errorf("Unexpected HTTP client password: %s", string(rt.password)) - } - if string(rt.passwordFile) != "" { - t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile) + if password, _ := rt.password.fetch(context.Background()); password != "secret" { + t.Errorf("Unexpected HTTP client password: %s", password) } } @@ -986,14 +982,81 @@ func TestBasicAuthPasswordFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if rt.username != "user" { - t.Errorf("Bad HTTP client username: %s", rt.username) + if username, _ := rt.username.fetch(context.Background()); username != "user" { + t.Errorf("Bad HTTP client username: %s", username) } - if string(rt.password) != "" { - t.Errorf("Bad HTTP client password: %s", rt.password) + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + t.Errorf("Bad HTTP client password: %s", password) } - if string(rt.passwordFile) != "testdata/basic-auth-password" { - t.Errorf("Bad HTTP client passwordFile: %s", rt.passwordFile) +} + +type secretManager struct { + data map[string]string +} + +func (m *secretManager) Fetch(ctx context.Context, secretRef string) (string, error) { + secretData, ok := m.data[secretRef] + if !ok { + return "", fmt.Errorf("unknown secret %s", secretRef) + } + return secretData, nil +} + +func TestBasicAuthSecretManager(t *testing.T) { + cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.ref.yaml") + if err != nil { + t.Fatalf("Error loading HTTP client config: %v", err) + } + manager := secretManager{ + data: map[string]string{ + "admin": "user", + "pass": "foobar", + }, + } + client, err := NewClientFromConfig(*cfg, "test", WithSecretManager(&manager)) + if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + + rt, ok := client.Transport.(*basicAuthRoundTripper) + if !ok { + t.Fatalf("Error casting to basic auth transport, %v", client.Transport) + } + + if username, _ := rt.username.fetch(context.Background()); username != "user" { + t.Errorf("Bad HTTP client username: %s", username) + } + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + t.Errorf("Bad HTTP client password: %s", password) + } +} + +func TestBasicAuthSecretManagerNotFound(t *testing.T) { + cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.ref.yaml") + if err != nil { + t.Fatalf("Error loading HTTP client config: %v", err) + } + manager := secretManager{ + data: map[string]string{ + "admin1": "user", + "foobar": "pass", + }, + } + client, err := NewClientFromConfig(*cfg, "test", WithSecretManager(&manager)) + if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + + rt, ok := client.Transport.(*basicAuthRoundTripper) + if !ok { + t.Fatalf("Error casting to basic auth transport, %v", client.Transport) + } + + if _, err := rt.username.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret admin") { + t.Errorf("Unexpected error message: %s", err) + } + if _, err := rt.password.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret pass") { + t.Errorf("Unexpected error message: %s", err) } } @@ -1012,14 +1075,11 @@ func TestBasicUsernameFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if rt.username != "" { - t.Errorf("Bad HTTP client username: %s", rt.username) - } - if string(rt.usernameFile) != "testdata/basic-auth-username" { - t.Errorf("Bad HTTP client usernameFile: %s", rt.usernameFile) + if username, _ := rt.username.fetch(context.Background()); username != "testuser" { + t.Errorf("Bad HTTP client username: %s", username) } - if string(rt.passwordFile) != "testdata/basic-auth-password" { - t.Errorf("Bad HTTP client passwordFile: %s", rt.passwordFile) + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + t.Errorf("Bad HTTP client passwordFile: %s", password) } } @@ -1248,7 +1308,7 @@ func TestTLSRoundTripper_Inline(t *testing.T) { certText: readFile(t, ClientCertificatePath), keyText: readFile(t, ClientKeyNoPassPath), - errMsg: "unable to use inline CA cert", + errMsg: "unable to use specified CA cert inline", }, { // Invalid cert. @@ -1408,7 +1468,7 @@ func TestTLSRoundTripperRaces(t *testing.T) { func TestHideHTTPClientConfigSecrets(t *testing.T) { c, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil { - t.Errorf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) + t.Fatalf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) } // String method must not reveal authentication credentials. @@ -1569,7 +1629,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) + rt := NewOAuth2RoundTripper(&inlineSecret{text: string(expectedConfig.ClientSecret)}, &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, @@ -1739,7 +1799,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) + rt := NewOAuth2RoundTripper(&inlineSecret{text: string(expectedConfig.ClientSecret)}, &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, diff --git a/config/testdata/http.conf.basic-auth.ref.yaml b/config/testdata/http.conf.basic-auth.ref.yaml new file mode 100644 index 00000000..68a7b3f8 --- /dev/null +++ b/config/testdata/http.conf.basic-auth.ref.yaml @@ -0,0 +1,3 @@ +basic_auth: + username_ref: admin + password_ref: pass \ No newline at end of file diff --git a/sigv4/go.mod b/sigv4/go.mod index 1b8ff1f4..05733979 100644 --- a/sigv4/go.mod +++ b/sigv4/go.mod @@ -5,8 +5,8 @@ go 1.20 replace github.com/prometheus/common => ../ require ( - github.com/aws/aws-sdk-go v1.51.32 - github.com/prometheus/client_golang v1.19.0 + github.com/aws/aws-sdk-go v1.53.14 + github.com/prometheus/client_golang v1.19.1 github.com/prometheus/common v0.48.0 github.com/stretchr/testify v1.9.0 gopkg.in/yaml.v2 v2.4.0 diff --git a/sigv4/go.sum b/sigv4/go.sum index ebdbf745..a407fad8 100644 --- a/sigv4/go.sum +++ b/sigv4/go.sum @@ -1,5 +1,5 @@ -github.com/aws/aws-sdk-go v1.51.32 h1:A6mPui7QP4mwmovyzgtdedbRbNur1Iu0/El7hBWNHms= -github.com/aws/aws-sdk-go v1.51.32/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go v1.53.14 h1:SzhkC2Pzag0iRW8WBb80RzKdGXDydJR9LAMs2GyKJ2M= +github.com/aws/aws-sdk-go v1.53.14/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= @@ -22,8 +22,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= -github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=