Skip to content

Commit

Permalink
checked that v is a pointer or nil before doing the request
Browse files Browse the repository at this point in the history
  • Loading branch information
vtopc committed Feb 18, 2022
1 parent 4a802af commit 78e13dd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
14 changes: 11 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"strings"

"github.com/vtopc/go-rest/defaults"
Expand Down Expand Up @@ -35,11 +36,18 @@ func NewClient(client *http.Client) *Client {

// Do executes HTTP request.
//
// Stores the result in the value pointed to by v. If v is nil or not a pointer,
// Do returns an InvalidUnmarshalError.
// Stores the result in the value pointed to by v. If v is not a nil and not a pointer,
// Do returns a json.InvalidUnmarshalError.
// Use func `http.NewRequestWithContext` to create `req`.
func (c *Client) Do(req *http.Request, v interface{}, expectedStatusCodes ...int) error {
// TODO: check that `v` is a pointer or nil
// check that `v` is a pointer or nil before doing the request.
if v != nil {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return &json.InvalidUnmarshalError{Type: reflect.TypeOf(v)}
}
}

if req == nil {
return errors.New("empty request")
}
Expand Down
15 changes: 15 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func TestClientDo(t *testing.T) {
respBody []byte
v interface{}
want interface{}
wantErr error
wantWrappedErr error
}{
"positive_get": {
Expand Down Expand Up @@ -65,6 +66,16 @@ func TestClientDo(t *testing.T) {
wantWrappedErr: errors.New("wrong status code (500 not in [200]): {\"error\":\"some error\"}"),
},

"negative_not_a_pointer": {
method: http.MethodGet,
urlPostfix: "/health",
statusCode: http.StatusOK,
expectedStatusCode: http.StatusOK,
respBody: []byte(`{"status":"ok"}`),
v: Struct{},
wantErr: errors.New("json: Unmarshal(non-pointer rest.Struct)"),
},

// TODO: add more test cases
}

Expand Down Expand Up @@ -94,6 +105,10 @@ func TestClientDo(t *testing.T) {

// test:
err = c.Do(req, tt.v, tt.expectedStatusCode)
if tt.wantErr != nil {
require.EqualError(t, err, tt.wantErr.Error())
return
}
if tt.wantWrappedErr != nil {
require.EqualError(t, errors.Unwrap(err), tt.wantWrappedErr.Error())
return
Expand Down
16 changes: 8 additions & 8 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestStatusCodeFromAPIError(t *testing.T) {

v interface{} // for .Do(...)

wantErr error
wantErr bool
wantStatusCode int
}{
"no_errors": {
Expand All @@ -62,21 +62,21 @@ func TestStatusCodeFromAPIError(t *testing.T) {
statusCode: 400,
expectedStatusCode: 201,
body: []byte(`{"errors":[{"message":"test error"}]}`),
wantErr: errors.New("wrong status code (400 not in [201]): {\"errors\":[{\"message\":\"test error\"}]}"),
wantErr: true,
wantStatusCode: 400,
},
"not_found": {
statusCode: 404,
expectedStatusCode: 200,
body: []byte(`{"errors":[{"message":"the entity not found"}]}`),
wantErr: errors.New("wrong status code (404 not in [200]): {\"errors\":[{\"message\":\"the entity not found\"}]}"),
wantErr: true,
wantStatusCode: 404,
},
"not_APIError": {
statusCode: 200,
expectedStatusCode: 200,
body: []byte(`{"foo":"bar"}`),
wantErr: errors.New("failed to unmarshal the response body: json: Unmarshal(non-pointer chan struct {})"),
wantErr: true,
v: make(chan struct{}), // a channel just to fail Unmarshal
wantStatusCode: 500,
},
Expand All @@ -85,7 +85,7 @@ func TestStatusCodeFromAPIError(t *testing.T) {
expectedStatusCode: 200,
body: []byte(`{"foo":"bar"}`),
v: new(S),
wantErr: errors.New(`wrong status code (201 not in [200]): {"foo":"bar"}`),
wantErr: true,
wantStatusCode: 500,
},
}
Expand All @@ -109,10 +109,10 @@ func TestStatusCodeFromAPIError(t *testing.T) {

err = client.Do(req, tt.v, tt.expectedStatusCode)
t.Logf("got error: %v", err)
if tt.wantErr == nil {
assert.NoError(t, err)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.EqualError(t, errors.Unwrap(err), tt.wantErr.Error())
assert.NoError(t, err)
}

// got := StatusCodeFromAPIError(err)
Expand Down

0 comments on commit 78e13dd

Please sign in to comment.