Skip to content

Commit

Permalink
MEN-1972 Failover Mender Server
Browse files Browse the repository at this point in the history
Simplified the failover mechanism from last commit. Instead of keep using
server from last succesful request; always cycle the server list on each
request.

Changelog: None

Signed-off-by: Alf-Rune Siqveland <[email protected]>
  • Loading branch information
alfrunes authored and Kristian Amlie committed Sep 28, 2018
1 parent e60f578 commit db1d6f0
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 199 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ _testmain.go
*.prof

mender

# Go test coverage output
*coverage*.txt
78 changes: 53 additions & 25 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ type ApiRequester interface {
Do(req *http.Request) (*http.Response, error)
}

// MenderServer is a placeholder for a full server definition used when
// multiple servers are given. The fields corresponds to the definitions
// given in menderConfig.
type MenderServer struct {
ServerURL string
// TODO: Move all possible server specific configurations in
// menderConfig over to this struct. (e.g. TenantToken?)
}

// APIError is an error type returned after receiving an error message from the
// server. It wraps a regular error with the request_id - and if
// the server returns an error message, this is also returned.
Expand Down Expand Up @@ -118,18 +127,18 @@ type ApiClient struct {
}

// function type for reauthorization closure (see func [email protected])
type ClientReauthorizeFunc func() (AuthToken, error)
type ClientReauthorizeFunc func(string) (AuthToken, error)

// function type for setting server (in case of multiple fallover servers)
type ServerManagementFunc func() string
type ServerManagementFunc func() *MenderServer

// Return a new ApiRequest
func (a *ApiClient) Request(code AuthToken, _nextServerIterator ServerManagementFunc, req ClientReauthorizeFunc) *ApiRequest {
func (a *ApiClient) Request(code AuthToken, nextServerIterator ServerManagementFunc, reauth ClientReauthorizeFunc) *ApiRequest {
return &ApiRequest{
api: a,
auth: code,
nextServerIterator: _nextServerIterator,
revoke: req,
nextServerIterator: nextServerIterator,
revoke: reauth,
}
}

Expand All @@ -148,13 +157,13 @@ type ApiRequest struct {

// tryDo is a wrapper around http.Do that also tries to reauthorize
// on a 401 response (Unauthorized).
func (ar *ApiRequest) tryDo(req *http.Request) (*http.Response, error) {
func (ar *ApiRequest) tryDo(req *http.Request, serverURL string) (*http.Response, error) {
r, err := ar.api.Do(req)
if r != nil && r.StatusCode == http.StatusUnauthorized {
if err == nil && r.StatusCode == http.StatusUnauthorized {
// invalid JWT; most likely the token is expired:
// Try to refresh it and reattempt sending the request
log.Info("Device unauthorized; attempting reauthorization")
if jwt, e := ar.revoke(); e == nil {
if jwt, e := ar.revoke(serverURL); e == nil {
// retry API request with new JWT token
ar.auth = jwt
// check if request had a body
Expand All @@ -175,29 +184,48 @@ func (ar *ApiRequest) tryDo(req *http.Request) (*http.Response, error) {

// Do is a wrapper for http.Do function for ApiRequests. This function in
// addition to calling http.Do handles client-server authorization header /
// reauthorization, as well as attempting failover servers (if given) if
// reauthorization, as well as attempting failover servers (if given) whenever
// the server "refuse" to serve the request.
func (ar *ApiRequest) Do(req *http.Request) (*http.Response, error) {
if ar.nextServerIterator == nil {
return nil, errors.New("Empty server list!")
}
if req.Header.Get("Authorization") == "" {
// Add JWT to header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ar.auth))
}
r, err := ar.tryDo(req)
// if: (failoverservers given) && (error / bad statuscode)
// then: cycle servers and retry request
if (ar.nextServerIterator != nil) &&
(r != nil && r.StatusCode >= 400 && r.StatusCode < 600) || (err != nil) {
// incrementally try listed servers
log.Warnf("Server %q does not serve API-Request: %q. %s",
req.URL.Host, req.URL.Path, "Attempting failover servers.")
// do {try next server} while (server refuse to serve)
for serverURL := ar.nextServerIterator(); serverURL != ""; serverURL = ar.nextServerIterator() {
// set new host (server) in request
req.URL.Host = serverURL
req.Host = serverURL
r, err = ar.tryDo(req)
if r != nil && (r.StatusCode < 400 || r.StatusCode >= 600) {
log.Infof("Automatically fell over to server: %q", serverURL)
var r *http.Response
var host string
var err error

server := ar.nextServerIterator()
for {
// Split host from URL
tmp := strings.Split(server.ServerURL, "://")
if len(tmp) == 1 {
host = tmp[0]
} else {
// (len >= 2) should usually be 2
host = tmp[1]
}

req.URL.Host = host
req.Host = host
r, err = ar.tryDo(req, server.ServerURL)
if err == nil && r.StatusCode < 400 {
break
}
prewHost := server.ServerURL
if server = ar.nextServerIterator(); server == nil {
break
}
log.Warnf("Server %q failed to serve request %q. Attempting %q",
prewHost, req.URL.Path, server.ServerURL)
}
if server != nil {
// reset server iterator
for {
if ar.nextServerIterator() == nil {
break
}
}
Expand Down
111 changes: 103 additions & 8 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,23 @@ import (
"github.com/stretchr/testify/require"
)

func dummy_reauthfunc() (AuthToken, error) {
return AuthToken(""), errors.New("")
func dummy_reauthfunc(str string) (AuthToken, error) {
return AuthToken("dummy"), nil
}

func dummy_srvMngmntFunc() string {
return ""
func dummy_srvMngmntFunc(url string) func() *MenderServer {
// mimic single server callback
srv := MenderServer{ServerURL: url}
called := false
return func() *MenderServer {
if called {
called = false
return nil
} else {
called = true
return &srv
}
}
}

func TestHttpClient(t *testing.T) {
Expand All @@ -63,9 +74,6 @@ func TestApiClientRequest(t *testing.T) {
)
assert.NotNil(t, cl)

req := cl.Request("foobar", dummy_srvMngmntFunc, dummy_reauthfunc)
assert.NotNil(t, req)

responder := &struct {
httpStatus int
headers http.Header
Expand All @@ -81,6 +89,19 @@ func TestApiClientRequest(t *testing.T) {
}))
defer ts.Close()

auth := false
req := cl.Request("foobar", dummy_srvMngmntFunc(ts.URL),
func(url string) (AuthToken, error) {
if !auth {
return AuthToken(""), errors.New("")
} else {
// reset httpstatus
responder.httpStatus = http.StatusOK
return AuthToken("dummy"), nil
}
}) /* cl.Request */
assert.NotNil(t, req)

hreq, _ := http.NewRequest(http.MethodGet, ts.URL, nil)

// ApiRequest should append Authorization header
Expand All @@ -97,6 +118,20 @@ func TestApiClientRequest(t *testing.T) {
assert.NotNil(t, rsp)
assert.NotNil(t, responder.headers)
assert.Equal(t, "Bearer zed", responder.headers.Get("Authorization"))

// should attempt reauthorization and fail
responder.httpStatus = http.StatusUnauthorized
rsp, err = req.Do(hreq)
assert.NoError(t, err)
assert.NotNil(t, rsp)
assert.Equal(t, rsp.StatusCode, http.StatusUnauthorized)

// successful reauthorization
auth = true
rsp, err = req.Do(hreq)
assert.NoError(t, err)
assert.NotNil(t, rsp)
assert.Equal(t, rsp.StatusCode, http.StatusOK)
}

func TestClientConnectionTimeout(t *testing.T) {
Expand All @@ -120,7 +155,7 @@ func TestClientConnectionTimeout(t *testing.T) {
assert.NotNil(t, cl)
assert.NoError(t, err)

req := cl.Request("foobar", dummy_srvMngmntFunc, dummy_reauthfunc)
req := cl.Request("foobar", dummy_srvMngmntFunc(ts.URL), dummy_reauthfunc)
assert.NotNil(t, req)

hreq, err := http.NewRequest(http.MethodGet, ts.URL, nil)
Expand Down Expand Up @@ -282,3 +317,63 @@ func TestUnMarshalErrorMessage(t *testing.T) {
expected := "failed to decode device group data: JSON payload is empty"
assert.Equal(t, expected, unmarshalErrorMessage(bytes.NewReader([]byte(jsonErrMsg))))
}

// Covers some special corner cases of the failover mechanism that is unique.
// In particular this test uses a list of two server where as one of them are
// fake so as to trigger a "failover" to the second server in the list.
// In addition it also covers the case with a 'nil' ServerManagementFunc.
func TestFailoverAPICall(t *testing.T) {
cl, err := NewApiClient(
Config{"server.crt", true, false},
)
assert.NotNil(t, cl)

responder := &struct {
httpStatus int
headers http.Header
}{
http.StatusOK,
http.Header{},
}

ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
responder.headers = r.Header
w.WriteHeader(responder.httpStatus)
w.Header().Set("Content-Type", "application/json")
}))
defer ts.Close()

mulServerfunc := func() func() *MenderServer {
// mimic multiple servers callback where the first one is a faker
srvrs := []MenderServer{MenderServer{ServerURL: "fakeURL.404"},
MenderServer{ServerURL: ts.URL}}
idx := 0
return func() *MenderServer {
var ret *MenderServer
if idx < len(srvrs) {
ret = &srvrs[idx]
idx++
} else {
ret = nil
}
return ret
}
}
req := cl.Request("foobar", mulServerfunc(), dummy_reauthfunc) /* cl.Request */
assert.NotNil(t, req)

hreq, _ := http.NewRequest(http.MethodGet, ts.URL, nil)

// ApiRequest should append Authorization header
rsp, err := req.Do(hreq)
assert.Nil(t, err)
assert.NotNil(t, rsp)
assert.NotNil(t, responder.headers)
assert.Equal(t, "Bearer foobar", responder.headers.Get("Authorization"))

req = cl.Request("foobar", nil, dummy_reauthfunc) /* cl.Request */
assert.NotNil(t, req)

rsp, err = req.Do(hreq)
assert.Error(t, err)
}
50 changes: 23 additions & 27 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ import (
"github.com/pkg/errors"
)

// menderServer is a placeholder for a full server definition used when
// multiple servers are given. The fields corresponds to the definitions
// given in menderConfig.
type menderServer struct {
ServerURL string
TenantToken string
}

type menderConfig struct {
// ClientProtocol "https"
ClientProtocol string
Expand Down Expand Up @@ -62,14 +54,14 @@ type menderConfig struct {

// Path to server SSL certificate
ServerCertificate string
// (Active) Server URL
// Server URL (For single server conf)
ServerURL string
// Path to deployment log file
UpdateLogPath string
// Server JWT TenantToken
TenantToken string
// List of available servers, to which client can fall over
Servers []menderServer
Servers []client.MenderServer
}

// LoadConfig parses the mender configuration json-file (/etc/mender/mender.conf)
Expand All @@ -86,25 +78,29 @@ func LoadConfig(configFile string) (*menderConfig, error) {
return nil, err
}

if confFromFile.Servers != nil {
if confFromFile.ServerURL != "" || confFromFile.TenantToken != "" {
log.Error("In mender.conf: don't specify both Servers field AND the corresponding fields in base structure (i.e. ServerURL etc.). The first server on the list on the list overwrites these fields.")
return nil, errors.New("Both Servers AND ServerURL / TenantToken given in mender.conf")
if confFromFile.Servers == nil {
if confFromFile.ServerURL == "" {
log.Warn("No server URL(s) specified in mender configuration.")
}
for i := 0; i < len(confFromFile.Servers); i++ {
// Trim possible '/' suffix, which is added back in URL path
if strings.HasSuffix(confFromFile.Servers[i].ServerURL, "/") {
confFromFile.Servers[i].ServerURL =
strings.TrimSuffix(confFromFile.Servers[i].ServerURL, "/")
}
}
// Overwrite "active" server with first one from the list of servers.
confFromFile.ServerURL = confFromFile.Servers[0].ServerURL
confFromFile.TenantToken = confFromFile.Servers[0].TenantToken
} else {
confFromFile.Servers = make([]client.MenderServer, 1)
confFromFile.Servers[0].ServerURL = confFromFile.ServerURL
} else if confFromFile.ServerURL != "" {
log.Error("In mender.conf: don't specify both Servers field " +
"AND the corresponding fields in base structure (i.e. " +
"ServerURL). The first server on the list on the" +
"list overwrites these fields.")
return nil, errors.New("Both Servers AND ServerURL given in " +
"mender.conf")
}
for i := 0; i < len(confFromFile.Servers); i++ {
// Trim possible '/' suffix, which is added back in URL path
if strings.HasSuffix(confFromFile.ServerURL, "/") {
confFromFile.ServerURL = strings.TrimSuffix(confFromFile.ServerURL, "/")
if strings.HasSuffix(confFromFile.Servers[i].ServerURL, "/") {
confFromFile.Servers[i].ServerURL =
strings.TrimSuffix(
confFromFile.Servers[i].ServerURL, "/")
}
if confFromFile.Servers[i].ServerURL == "" {
log.Warnf("Server entry %d has no associated server URL.")
}
}

Expand Down
Loading

0 comments on commit db1d6f0

Please sign in to comment.