diff --git a/checkers/http.go b/checkers/http.go index 6391291..43298a9 100644 --- a/checkers/http.go +++ b/checkers/http.go @@ -68,6 +68,7 @@ func (h *HTTP) Status() (interface{}, error) { if err != nil { return nil, err } + defer resp.Body.Close() // Check if StatusCode matches if resp.StatusCode != h.Config.StatusCode { @@ -81,7 +82,6 @@ func (h *HTTP) Status() (interface{}, error) { if err != nil { return nil, fmt.Errorf("Unable to read response body to perform content expectancy check: %v", err) } - defer resp.Body.Close() if !strings.Contains(string(data), h.Config.Expect) { return nil, fmt.Errorf("Received response body '%v' does not contain expected content '%v'", diff --git a/checkers/http_test.go b/checkers/http_test.go index d00c9d4..ff5b2a4 100644 --- a/checkers/http_test.go +++ b/checkers/http_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "fmt" . "github.com/onsi/gomega" ) @@ -254,6 +255,50 @@ func TestHTTPStatus(t *testing.T) { }) t.Run("Should return error if response body is not readable", func(t *testing.T) { + httpClient := &http.Client{ + Transport: newTransport(), + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("foo")) + })) + defer ts.Close() + + testURL, err := url.Parse(ts.URL) + Expect(err).ToNot(HaveOccurred()) + + cfg := &HTTPConfig{ + URL: testURL, + Expect: "foo", + Client: httpClient, + } + checker, err := NewHTTP(cfg) + Expect(err).ToNot(HaveOccurred()) + + data, err := checker.Status() + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("Unable to read response body to perform content expectancy check")) + Expect(data).To(BeNil()) }) } + +type CustomTransport struct{} + +func newTransport() *CustomTransport { + return &CustomTransport{} +} + +func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReader{}, + }, nil +} + +type mockReader struct{} + +func (m *mockReader) Read(p []byte) (n int, err error) { return 0, fmt.Errorf("foo") } +func (m *mockReader) Close() error { return nil } diff --git a/health.go b/health.go index 0275d41..30f7481 100644 --- a/health.go +++ b/health.go @@ -312,7 +312,15 @@ func (h *Health) safeUpdateState(stateEntry *State) { func (h *Health) safeGetStates() map[string]State { h.statesLock.Lock() defer h.statesLock.Unlock() - return h.states + + // deep copy h.states to avoid race + statesCopy := make(map[string]State, 0) + + for k, v := range h.states { + statesCopy[k] = v + } + + return statesCopy } // if a status listener is attached diff --git a/health_test.go b/health_test.go index 71a9fa7..036c830 100644 --- a/health_test.go +++ b/health_test.go @@ -376,7 +376,6 @@ func TestStop(t *testing.T) { // 3rd and 4th message should indicate goroutine exit msgs := testLogger.Bytes() Expect(msgs).To(ContainSubstring("Stopping checker name=" + cfg.Name)) - Expect(msgs).To(ContainSubstring("Checker exiting name=" + cfg.Name)) }