diff --git a/middleware_static_filesystem_test.go b/middleware_static_filesystem_test.go index 6d6e812..0c25754 100644 --- a/middleware_static_filesystem_test.go +++ b/middleware_static_filesystem_test.go @@ -31,7 +31,7 @@ var _ = Describe("Static File Middleware", func() { It("should return a response", func() { path = "/static-examples/dist/" - request, _ := http.NewRequest("GET", "/dist/test.html", nil) + request, _ = http.NewRequest("GET", "/dist/test.html", nil) resp := NewStaticFilesystem(testPath+path, "/dist/")(response, request) Expect(resp).To(BeNil()) diff --git a/rye.go b/rye.go index 0339a3d..022649f 100644 --- a/rye.go +++ b/rye.go @@ -12,6 +12,7 @@ import ( "time" //log "github.com/sirupsen/logrus" + "github.com/cactus/go-statsd-client/statsd" ) @@ -24,6 +25,12 @@ type MWHandler struct { beforeHandlers []Handler } +// CustomStatter allows the client to log any additional statsD metrics Rye +// computes around the request handler. +type CustomStatter interface { + ReportStats(handlerName string, elapsedTime time.Duration, req *http.Request, resp *Response) error +} + // Config struct allows you to set a reference to a statsd.Statter and include it's stats rate. type Config struct { Statter statsd.Statter @@ -33,6 +40,9 @@ type Config struct { NoErrStats bool NoDurationStats bool NoStatusCodeStats bool + + // Customer Statter for the client + CustomStatter CustomStatter } // JSONStatus is a simple container used for conveying status messages. @@ -148,6 +158,12 @@ func (m *MWHandler) do(w http.ResponseWriter, r *http.Request, handler Handler) // Record status code metric (default 2xx) go m.reportStatusCode(handlerName, statusCode) } + + // If a CustomStatter is set, send the handler metrics to it. + // This allows the client to handle these metrics however it wants. + if m.Config.CustomStatter != nil && resp != nil { + go m.Config.CustomStatter.ReportStats(handlerName, time.Since(startTime), r, resp) + } }() // stop executing rest of the diff --git a/rye_test.go b/rye_test.go index a0169c2..9761ded 100644 --- a/rye_test.go +++ b/rye_test.go @@ -36,16 +36,33 @@ type statsTiming struct { StatRate float32 } +var reportedStats = make(chan fakeReportedStats) + +type fakeReportedStats struct { + HandlerName string + Duration time.Duration + Request *http.Request + Response *Response +} + +type fakeCustomStatter struct{} + +func (fcs *fakeCustomStatter) ReportStats(handler string, dur time.Duration, req *http.Request, res *Response) error { + reportedStats <- fakeReportedStats{handler, dur, req, res} + return nil +} + var _ = Describe("Rye", func() { var ( - request *http.Request - response *httptest.ResponseRecorder - mwHandler *MWHandler - ryeConfig Config - fakeStatter *statsdfakes.FakeStatter - inc chan statsInc - timing chan statsTiming + request *http.Request + response *httptest.ResponseRecorder + mwHandler *MWHandler + ryeConfig Config + fakeStatter *statsdfakes.FakeStatter + fakeClientStatter *fakeCustomStatter + inc chan statsInc + timing chan statsTiming ) const ( @@ -54,6 +71,7 @@ var _ = Describe("Rye", func() { BeforeEach(func() { fakeStatter = &statsdfakes.FakeStatter{} + fakeClientStatter = &fakeCustomStatter{} ryeConfig = Config{ Statter: fakeStatter, StatRate: STATRATE, @@ -81,7 +99,6 @@ var _ = Describe("Rye", func() { timing <- statsTiming{name, time, statrate} return nil } - }) AfterEach(func() { @@ -329,6 +346,64 @@ var _ = Describe("Rye", func() { Expect(fakeStatter.IncCallCount()).To(Equal(2)) }) }) + + Context("when a custom statter is supplied", func() { + It("should call the ReportStats method", func() { + ryeConfig := Config{ + Statter: fakeStatter, + StatRate: STATRATE, + CustomStatter: fakeClientStatter, + } + + handler := NewMWHandler(ryeConfig) + h := handler.Handle([]Handler{successWithResponse}) + h.ServeHTTP(response, request) + + Expect(h).ToNot(BeNil()) + Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {})) + + Eventually(inc).Should(Receive(Equal(statsInc{"handlers.successWithResponse.200", 1, float32(STATRATE)}))) + Eventually(timing).Should(Receive(HaveTiming("handlers.successWithResponse.runtime", float32(STATRATE)))) + + var receivedReportedStats fakeReportedStats + var resp *Response + + Eventually(reportedStats).Should(Receive(&receivedReportedStats)) + Expect(receivedReportedStats.HandlerName).To(Equal("successWithResponse")) + Expect(receivedReportedStats.Duration.Seconds()/1000 > 0).To(Equal(true)) + Expect(receivedReportedStats.Request).To(BeAssignableToTypeOf(request)) + Expect(receivedReportedStats.Response).To(BeAssignableToTypeOf(resp)) + Expect(receivedReportedStats.Response.StatusCode).To(Equal(200)) + }) + }) + + Context("when a custom statter is NOT supplied", func() { + It("should not call the ReportStats method", func() { + ryeConfig := Config{ + Statter: fakeStatter, + StatRate: STATRATE, + } + + handler := NewMWHandler(ryeConfig) + h := handler.Handle([]Handler{successWithResponse}) + h.ServeHTTP(response, request) + + Expect(h).ToNot(BeNil()) + Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {})) + + Eventually(inc).Should(Receive(Equal(statsInc{"handlers.successWithResponse.200", 1, float32(STATRATE)}))) + Eventually(timing).Should(Receive(HaveTiming("handlers.successWithResponse.runtime", float32(STATRATE)))) + + time.Sleep(time.Millisecond * 10) + + var receivedReportedStats fakeReportedStats + + Expect(receivedReportedStats.HandlerName).To(Equal("")) + Expect(receivedReportedStats.Duration.Nanoseconds()).To(Equal(int64(0))) + Expect(receivedReportedStats.Request).To(BeNil()) + Expect(receivedReportedStats.Response).To(BeNil()) + }) + }) }) Describe("getFuncName", func() { @@ -385,6 +460,15 @@ func success2Handler(rw http.ResponseWriter, r *http.Request) *Response { return nil } +func successWithResponse(rw http.ResponseWriter, r *http.Request) *Response { + return &Response{ + StatusCode: 200, + Err: nil, + StopExecution: false, + Context: context.Background(), + } +} + func badResponseHandler(rw http.ResponseWriter, r *http.Request) *Response { return &Response{} }