Skip to content

Commit

Permalink
Merge pull request #24 from InVisionApp/qparamtoken
Browse files Browse the repository at this point in the history
query param access token
  • Loading branch information
talpert authored Mar 22, 2017
2 parents e4ee692 + 139d65e commit 759e085
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 18 deletions.
57 changes: 50 additions & 7 deletions middleware_accesstoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import (
)

type accessTokens struct {
headerName string
tokens []string
paramName string
tokens []string
getFunc func(string, *http.Request) string
missingMessage string
}

/*
NewMiddlewareAccessToken creates a new handler to verify access tokens in a rye chain.
NewMiddlewareAccessToken creates a new handler to verify access tokens passed as a header.
Example usage:
Expand All @@ -23,19 +25,60 @@ Example usage:
})).Methods("POST")
*/
func NewMiddlewareAccessToken(headerName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response {
return newAccessTokenHandler(headerName, tokens, "header")
}

/*
NewMiddlewareAccessQueryToken creates a new handler to verify access tokens passed as a query parameter.
Example usage:
routes.Handle("/some/route", a.Dependencies.MWHandler.Handle(
[]rye.Handler{
rye.NewMiddlewareAccessQueryToken(queryParamName, []string{token1, token2}),
yourHandler,
})).Methods("POST")
*/
func NewMiddlewareAccessQueryToken(queryParamName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response {
return newAccessTokenHandler(queryParamName, tokens, "query")
}

func newAccessTokenHandler(name string, tokens []string, tokenType string) func(rw http.ResponseWriter, req *http.Request) *Response {
a := &accessTokens{
headerName: headerName,
tokens: tokens,
paramName: name,
tokens: tokens,
}

switch tokenType {

case "query":
a.getFunc = func(s string, r *http.Request) string {
q, ok := r.URL.Query()[s]
if !ok {
return ""
}

return q[0]
}
a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name)

default:
// default to using the header
a.getFunc = func(s string, r *http.Request) string {
return r.Header.Get(s)
}
a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name)
}

return a.handle
}

func (a *accessTokens) handle(rw http.ResponseWriter, r *http.Request) *Response {
token := r.Header.Get(a.headerName)
token := a.getFunc(a.paramName, r)

if token == "" {
return &Response{
Err: fmt.Errorf("No access token found; ensure you pass '%s' in header", a.headerName),
Err: errors.New(a.missingMessage),
StatusCode: http.StatusUnauthorized,
}
}
Expand Down
138 changes: 127 additions & 11 deletions middleware_accesstoken_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package rye

import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
Expand All @@ -14,39 +16,48 @@ var _ = Describe("AccessToken Middleware", func() {
request *http.Request
response *httptest.ResponseRecorder

tokenHeaderName = "at-hname"
token1, token2 string
testHandler func(http.ResponseWriter, *http.Request) *Response

token1, token2 string
)

BeforeEach(func() {
response = httptest.NewRecorder()
request = &http.Request{
Header: map[string][]string{},
}

token1 = "test1"
token2 = "test2"
})

Describe("handle", func() {
Context("header token", func() {
var (
tokenHeaderName = "at-hname"
)

BeforeEach(func() {
testHandler = NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})
request = &http.Request{
Header: map[string][]string{},
}
})

Context("when a valid token is used", func() {
It("should return nil", func() {
request.Header.Add(tokenHeaderName, token1)
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})

It("should return nil", func() {
request.Header.Add(tokenHeaderName, token2)
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})
})

Context("when an invalid token is used", func() {
It("should return an error", func() {
request.Header.Add(tokenHeaderName, "blah")
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
Expand All @@ -56,7 +67,7 @@ var _ = Describe("AccessToken Middleware", func() {

Context("when no token header exists", func() {
It("should return an error", func() {
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expand All @@ -67,12 +78,117 @@ var _ = Describe("AccessToken Middleware", func() {
Context("when token header is blank", func() {
It("should return an error", func() {
request.Header.Add(tokenHeaderName, "")
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})
})

Context("query param token", func() {
var (
qParamName string
qParams string
)

BeforeEach(func() {
qParamName = "token"
testHandler = NewMiddlewareAccessQueryToken(qParamName, []string{token1, token2})
})

JustBeforeEach(func() {
u, err := url.Parse(fmt.Sprintf("http://doesntmatter.io/blah?%s", qParams))
Expect(err).ToNot(HaveOccurred())

request = &http.Request{
URL: u,
}
})

Context("when a valid token is used", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=%s", qParamName, token1)
})

It("should return nil", func() {
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})
})

Context("when the other valid token is used", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=%s", qParamName, token2)
})

It("should return nil", func() {
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})
})

Context("when an invalid token is used", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=blah", qParamName)
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

Context("when no token param exists", func() {
BeforeEach(func() {
qParams = "something=else"
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

Context("when token param is blank", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=''", qParamName)
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

Context("when no query params", func() {
JustBeforeEach(func() {
u, err := url.Parse("http://doesntmatter.io/blah")
Expect(err).ToNot(HaveOccurred())

request = &http.Request{
URL: u,
}
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

})
})

0 comments on commit 759e085

Please sign in to comment.