Skip to content

Commit

Permalink
Merge pull request #26 from InVisionApp/getheader
Browse files Browse the repository at this point in the history
get header middleware
  • Loading branch information
caledhwa authored Jul 10, 2017
2 parents 759e085 + da1a1c0 commit 5c1e2eb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
40 changes: 40 additions & 0 deletions middleware_getheader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package rye

import (
"context"
"net/http"
)

type getHeader struct {
headerName string
contextKey string
}

/*
NewMiddlewareGetHeader creates a new handler to extract any header and save its value into the context.
headerName: the name of the header you want to extract
contextKey: the value key that you would like to store this header under in the context
Example usage:
routes.Handle("/some/route", a.Dependencies.MWHandler.Handle(
[]rye.Handler{
rye.NewMiddlewareGetHeader(headerName, contextKey),
yourHandler,
})).Methods("POST")
*/
func NewMiddlewareGetHeader(headerName, contextKey string) func(rw http.ResponseWriter, req *http.Request) *Response {
h := getHeader{headerName: headerName, contextKey: contextKey}
return h.getHeaderMiddleware
}

func (h *getHeader) getHeaderMiddleware(rw http.ResponseWriter, r *http.Request) *Response {
rID := r.Header.Get(h.headerName)
if rID != "" {
return &Response{
Context: context.WithValue(r.Context(), h.contextKey, rID),
}
}

return nil
}
44 changes: 44 additions & 0 deletions middleware_getheader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package rye

import (
"net/http"
"net/http/httptest"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("Get Header Middleware", func() {
var (
request *http.Request
response *httptest.ResponseRecorder
)

BeforeEach(func() {
response = httptest.NewRecorder()
request = &http.Request{
Header: make(map[string][]string, 0),
}
})

Describe("getHeaderMiddleware", func() {
Context("when a valid header is passed", func() {
It("should return context with value", func() {
headerName := "SpecialHeader"
ctxKey := "special"
request.Header.Add(headerName, "secret value")
resp := NewMiddlewareGetHeader(headerName, ctxKey)(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Context).ToNot(BeNil())
Expect(resp.Context.Value(ctxKey)).To(Equal("secret value"))
})
})

Context("when no header is passed", func() {
It("should have no value in context", func() {
resp := NewMiddlewareGetHeader("something", "not there")(response, request)
Expect(resp).To(BeNil())
})
})
})
})

0 comments on commit 5c1e2eb

Please sign in to comment.