Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix SPNEGO HTTP client redirect loop detection. #531

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions v8/spnego/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ import (

// Client side functionality //

var errRedirectLoop = errors.New("stopped after 10 redirects")

// Client will negotiate authentication with a server using SPNEGO.
type Client struct {
*http.Client
krb5Client *client.Client
spn string
reqs []*http.Request
}

type redirectErr struct {
Expand Down Expand Up @@ -80,6 +81,14 @@ func NewClient(krb5Cl *client.Client, httpCl *http.Client, spn string) *Client {

// Do is the SPNEGO enabled HTTP client's equivalent of the http.Client's Do method.
func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
return c.do(req, nil)
}

func (c *Client) do(req *http.Request, via []*http.Request) (resp *http.Response, err error) {
if len(via) >= 10 {
return resp, errRedirectLoop
}

var body bytes.Buffer
if req.Body != nil {
// Use a tee reader to capture any body sent in case we have to replay it again
Expand All @@ -93,15 +102,11 @@ func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
if e, ok := ue.Err.(redirectErr); ok {
// Picked up a redirect
e.reqTarget.Header.Del(HTTPHeaderAuthRequest)
c.reqs = append(c.reqs, e.reqTarget)
if len(c.reqs) >= 10 {
return resp, errors.New("stopped after 10 redirects")
}
if req.Body != nil {
// Refresh the body reader so the body can be sent again
e.reqTarget.Body = io.NopCloser(&body)
}
return c.Do(e.reqTarget)
return c.do(e.reqTarget, append(via, e.reqTarget))
}
}
return resp, err
Expand All @@ -117,7 +122,7 @@ func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
return c.Do(req)
return c.do(req, via)
}
return resp, err
}
Expand Down
99 changes: 99 additions & 0 deletions v8/spnego/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net/http/cookiejar"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"

Expand All @@ -28,6 +29,104 @@ import (
"github.com/stretchr/testify/assert"
)

// fakeTransport is a transport implementation allowing specification of HTTP responses.
type fakeTransport map[string]func() *http.Response

func (ft fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
// Note: This doesn't support requests with multiple parameters, since parameter ordering in the string is nondeterministic.
if resp, ok := ft[r.URL.String()]; ok {
return resp(), nil
}
return nil, fmt.Errorf("unexpected url: %s", r.URL.String())
}

func httpResponse(statusCode int, body string, extraHeaders map[string]string) func() *http.Response {
return func() *http.Response {
resp := &http.Response{
StatusCode: statusCode,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(body)),
}
for k, v := range extraHeaders {
resp.Header.Add(k, v)
}
return resp
}
}

func TestClient_Do(t *testing.T) {
hc := &http.Client{
Transport: fakeTransport{
"http://example.com/page1": httpResponse(http.StatusOK, "page 1 ok", nil),
"http://example.com/redirect": httpResponse(http.StatusFound, "shouldn't see this", map[string]string{"Location": "http://example.com/page1"}),
"http://example.com/loop": httpResponse(http.StatusFound, "down the spiral you go", map[string]string{"Location": "http://example.com/loop"}),
},
}

tests := []struct {
desc string
target string
wantStatusCode int
wantBody string
wantErr error
}{
{
desc: "single page",
target: "http://example.com/page1",
wantStatusCode: http.StatusOK,
wantBody: "page 1 ok",
},
{
desc: "single redirect hop",
target: "http://example.com/redirect",
wantStatusCode: http.StatusOK,
wantBody: "page 1 ok",
},
{
desc: "redirect loop",
target: "http://example.com/loop",
wantErr: errRedirectLoop,
},
}

for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
// So long as we don't return any 401s, we can safely run this without needing to fall through to Kerberos auth.
spnClient := NewClient(nil, hc, "")

// Run the request multiple times, to ensure our loop detection is properly counting redirects on a per-request basis.
for i := 0; i < 20; i++ {
req, err := http.NewRequest("GET", tc.target, nil)
if err != nil {
t.Fatalf("Failed to create GET request for target %s: %v", tc.target, err)
}

resp, err := spnClient.Do(req)
if err == nil {
defer resp.Body.Close()
}
if !errors.Is(err, tc.wantErr) {
t.Fatalf("Unexpected error result while attempting GET request for target %s - got: %v, want: %v", tc.target, err, tc.wantErr)
}

if err == nil {
if resp.StatusCode != tc.wantStatusCode {
t.Errorf("Unexpected status code - got: %d, want: %d", resp.StatusCode, tc.wantStatusCode)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body for target %s: %v", tc.target, err)
}
body := string(b)
if body != tc.wantBody {
t.Errorf("Unexpected response body - got: %s, want: %s", body, tc.wantBody)
}
}
}
})
}
}

func TestClient_SetSPNEGOHeader(t *testing.T) {
test.Integration(t)
b, _ := hex.DecodeString(testdata.KEYTAB_TESTUSER1_TEST_GOKRB5)
Expand Down