diff --git a/internal/netbox/client.go b/internal/netbox/client.go index 2d1eb0be..fb73bc0f 100644 --- a/internal/netbox/client.go +++ b/internal/netbox/client.go @@ -81,10 +81,9 @@ func NewClient(apiURL, apiToken string, opts ...ClientOption) (Client, error) { } c := &client{ - httpClient: retryableHTTPClient(5), - baseURL: strings.TrimSuffix(u.String(), "/"), - token: apiToken, - logger: log.L(), + baseURL: strings.TrimSuffix(u.String(), "/"), + token: apiToken, + logger: log.L(), } for _, opt := range opts { @@ -93,6 +92,8 @@ func NewClient(apiURL, apiToken string, opts ...ClientOption) (Client, error) { } } + c.setRetryableHTTPClient(5) + if c.rateLimiter == nil { c.rateLimiter = rate.NewLimiter(rate.Inf, 1) } @@ -154,10 +155,14 @@ func parseAndValidateURL(apiURL string) (*url.URL, error) { return u, nil } -func retryableHTTPClient(retryMax int) *retryablehttp.Client { +func (c *client) setRetryableHTTPClient(retryMax int) { // add retries on 50X errors retryClient := retryablehttp.NewClient() retryClient.RetryMax = retryMax + if c.logger != nil { + retryClient.Logger = newRetryableHTTPLogger(c.logger) + } + retryClient.CheckRetry = func(ctx context.Context, res *http.Response, err error) (bool, error) { if err == nil { // do not retry non-idempotent requests @@ -169,7 +174,7 @@ func retryableHTTPClient(retryMax int) *retryablehttp.Client { return retryablehttp.DefaultRetryPolicy(ctx, res, err) } - return retryClient + c.httpClient = retryClient } // NOTE: trailing "/" is required for endpoints that work with a single object ID diff --git a/internal/netbox/client_test.go b/internal/netbox/client_test.go index e24a4461..34664722 100644 --- a/internal/netbox/client_test.go +++ b/internal/netbox/client_test.go @@ -21,6 +21,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "go.uber.org/zap" ) func TestParseAndValidateURL(t *testing.T) { @@ -59,7 +61,8 @@ func TestParseAndValidateURL(t *testing.T) { } func TestRetryableHTTPClient(t *testing.T) { - client := retryableHTTPClient(1) + c := &client{logger: zap.L()} + c.setRetryableHTTPClient(1) t.Run("idempotent requests retried", func(t *testing.T) { var numCalls int @@ -69,7 +72,7 @@ func TestRetryableHTTPClient(t *testing.T) { })) defer ts.Close() - client.Get(ts.URL) + c.httpClient.Get(ts.URL) numRetries := numCalls - 1 if numRetries != 1 { @@ -85,7 +88,7 @@ func TestRetryableHTTPClient(t *testing.T) { })) defer ts.Close() - client.Post(ts.URL, "application/json", bytes.NewBufferString(`{}`)) + c.httpClient.Post(ts.URL, "application/json", bytes.NewBufferString(`{}`)) numRetries := numCalls - 1 if numRetries != 0 { diff --git a/internal/netbox/log.go b/internal/netbox/log.go new file mode 100644 index 00000000..6350ab4c --- /dev/null +++ b/internal/netbox/log.go @@ -0,0 +1,45 @@ +package netbox + +import ( + retryablehttp "github.com/hashicorp/go-retryablehttp" + "go.uber.org/zap" +) + +// retryableHTTPLogger is a wrapper for zap logger that implements retyablehttp.LeveledLogger +// interface and therefore can be passed to a retryablehttp client +type retryableHTTPLogger struct { + logger *zap.Logger +} + +func newRetryableHTTPLogger(logger *zap.Logger) retryablehttp.LeveledLogger { + return &retryableHTTPLogger{logger: logger} +} + +func (l *retryableHTTPLogger) Error(msg string, keysAndValues ...interface{}) { + l.logger.Error(msg, fieldsFromKeysAndValues(keysAndValues)...) +} + +func (l *retryableHTTPLogger) Info(msg string, keysAndValues ...interface{}) { + l.logger.Info(msg, fieldsFromKeysAndValues(keysAndValues)...) +} + +func (l *retryableHTTPLogger) Debug(msg string, keysAndValues ...interface{}) { + l.logger.Info(msg, fieldsFromKeysAndValues(keysAndValues)...) +} + +func (l *retryableHTTPLogger) Warn(msg string, keysAndValues ...interface{}) { + l.logger.Info(msg, fieldsFromKeysAndValues(keysAndValues)...) +} + +func fieldsFromKeysAndValues(keysAndValues []interface{}) []zap.Field { + var fields []zap.Field + for i := 1; i < len(keysAndValues); i += 2 { + key := keysAndValues[i-1] + value := keysAndValues[i] + if keyStr, ok := key.(string); ok { + fields = append(fields, zap.Any(keyStr, value)) + } + // ignore malformed key-value pair + } + return fields +} diff --git a/internal/netbox/log_test.go b/internal/netbox/log_test.go new file mode 100644 index 00000000..725985ab --- /dev/null +++ b/internal/netbox/log_test.go @@ -0,0 +1,43 @@ +package netbox + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" +) + +func TestFieldsFromKeysAndValues(t *testing.T) { + tests := []struct { + name string + keysAndValues []interface{} + expectedFields []zap.Field + }{{ + name: "empty", + }, { + name: "simple string pair", + keysAndValues: []interface{}{"foo", "bar"}, + expectedFields: []zap.Field{zap.Any("foo", "bar")}, + }, { + name: "multiple pairs", + keysAndValues: []interface{}{"foo", 1, "bar", true}, + expectedFields: []zap.Field{zap.Any("foo", 1), zap.Any("bar", true)}, + }, { + name: "key without value", + keysAndValues: []interface{}{"foo", "bar", "baz"}, + expectedFields: []zap.Field{zap.Any("foo", "bar")}, + }, { + name: "key is not a string", + keysAndValues: []interface{}{100, "bar"}, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + fields := fieldsFromKeysAndValues(test.keysAndValues) + + if diff := cmp.Diff(test.expectedFields, fields); diff != "" { + t.Errorf("\n (-want, +got)\n%s", diff) + } + }) + } +}