diff --git a/pkg/db/aggregate.go b/pkg/db/aggregate.go index 6718046f..2335ef72 100644 --- a/pkg/db/aggregate.go +++ b/pkg/db/aggregate.go @@ -37,7 +37,8 @@ func (q *Query) Count(opts ...CountOptions) (int, error) { headers["Prefer"] = "count=" + countVal headers["Range-Unit"] = "items" - _, resp, err := PostgrestRequest(q.Context, fasthttp.MethodHead, url, nil, headers) + var a interface{} + resp, err := PostgrestRequestBind(q.Context, fasthttp.MethodHead, url, nil, headers, q.ByPass, &a) if err != nil { return 0, err } diff --git a/pkg/db/db.go b/pkg/db/db.go index e728df17..72f69b86 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -21,6 +21,7 @@ type Query struct { LimitValue int OffsetValue int Errors []error + ByPass bool } type ModelBase struct { @@ -34,6 +35,7 @@ func (q *Query) HasError() bool { func NewQuery(ctx raiden.Context) *Query { return &Query{ Context: ctx, + ByPass: false, } } @@ -47,6 +49,11 @@ func (q *Query) From(m interface{}) *Query { return q } +func (q *Query) AsSystem() *Query { + q.ByPass = true + return q +} + func GetTable(m interface{}) string { t := reflect.TypeOf(m) @@ -69,7 +76,7 @@ func (m *ModelBase) Execute() (model *ModelBase) { return m } -func (q Query) Get() ([]byte, error) { +func (q Query) Get(collection interface{}) error { url := q.GetUrl() @@ -77,27 +84,27 @@ func (q Query) Get() ([]byte, error) { headers["Content-Type"] = "application/json" headers["Prefer"] = "return=representation" - resp, _, err := PostgrestRequest(q.Context, fasthttp.MethodGet, url, nil, headers) + _, err := PostgrestRequestBind(q.Context, fasthttp.MethodGet, url, nil, headers, q.ByPass, collection) if err != nil { - return nil, err + return err } - return resp, nil + return nil } -func (q Query) Single() ([]byte, error) { +func (q Query) Single(model interface{}) error { url := q.Limit(1).GetUrl() headers := make(map[string]string) headers["Accept"] = "application/vnd.pgrst.object+json" - res, _, err := PostgrestRequest(q.Context, "GET", url, nil, headers) + _, err := PostgrestRequestBind(q.Context, "GET", url, nil, headers, q.ByPass, model) if err != nil { - return nil, err + return err } - return res, nil + return nil } func (q Query) GetUrl() string { diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go index 9b98e01d..ff130b07 100644 --- a/pkg/db/db_test.go +++ b/pkg/db/db_test.go @@ -112,13 +112,22 @@ func TestGetTable(t *testing.T) { } func TestSingle(t *testing.T) { - _, err := NewQuery(&mockRaidenContext).Model(articleMockModel).Single() + var articleMockModel = ArticleMockModel{} + err := NewQuery(&mockRaidenContext).Model(articleMockModel).Single(&articleMockModel) assert.NoError(t, err) } func TestGet(t *testing.T) { - _, err := NewQuery(&mockRaidenContext).Model(articleMockModel).Get() + var collection interface{} + err := NewQuery(&mockRaidenContext).Model(articleMockModel).Get(&collection) + + assert.NoError(t, err) +} + +func TestByPass(t *testing.T) { + var collection interface{} + err := NewQuery(&mockRaidenContext).Model(articleMockModel).AsSystem().Get(&collection) assert.NoError(t, err) } diff --git a/pkg/db/delete.go b/pkg/db/delete.go index be198aa7..1b65b48c 100644 --- a/pkg/db/delete.go +++ b/pkg/db/delete.go @@ -4,17 +4,18 @@ import ( "github.com/valyala/fasthttp" ) -func (q *Query) Delete() ([]byte, error) { +func (q *Query) Delete() error { url := q.GetUrl() headers := make(map[string]string) headers["Content-Type"] = "application/json" headers["Prefer"] = "return=representation" - resp, _, err := PostgrestRequest(q.Context, fasthttp.MethodDelete, url, nil, headers) + var a interface{} + _, err := PostgrestRequestBind(q.Context, fasthttp.MethodDelete, url, nil, headers, q.ByPass, &a) if err != nil { - return nil, err + return err } - return resp, nil + return nil } diff --git a/pkg/db/delete_test.go b/pkg/db/delete_test.go index 55c8833c..a79f9584 100644 --- a/pkg/db/delete_test.go +++ b/pkg/db/delete_test.go @@ -7,7 +7,7 @@ import ( ) func TestDelete(t *testing.T) { - _, err := NewQuery(&mockRaidenContext).Model(articleMockModel).Delete() + err := NewQuery(&mockRaidenContext).Model(articleMockModel).Delete() assert.NoError(t, err) } diff --git a/pkg/db/insert.go b/pkg/db/insert.go index 24a29671..2843afe8 100644 --- a/pkg/db/insert.go +++ b/pkg/db/insert.go @@ -6,10 +6,10 @@ import ( "github.com/valyala/fasthttp" ) -func (q *Query) Insert(payload interface{}) ([]byte, error) { +func (q *Query) Insert(payload interface{}, model interface{}) error { jsonData, err := json.Marshal(payload) if err != nil { - return nil, err + return err } url := q.GetUrl() @@ -18,10 +18,10 @@ func (q *Query) Insert(payload interface{}) ([]byte, error) { headers["Content-Type"] = "application/json" headers["Prefer"] = "return=representation" - body, _, err := PostgrestRequest(q.Context, fasthttp.MethodPost, url, jsonData, headers) - if err != nil { - return nil, err + _, err0 := PostgrestRequestBind(q.Context, fasthttp.MethodPost, url, jsonData, headers, q.ByPass, model) + if err0 != nil { + return err0 } - return body, nil + return nil } diff --git a/pkg/db/insert_test.go b/pkg/db/insert_test.go index 895e9acf..840c9c64 100644 --- a/pkg/db/insert_test.go +++ b/pkg/db/insert_test.go @@ -16,9 +16,9 @@ func TestInsert(t *testing.T) { CreatedAt: time.Now(), } - _, err := NewQuery(&mockRaidenContext). + err := NewQuery(&mockRaidenContext). Model(articleMockModel). - Insert(article) + Insert(article, &articleMockModel) assert.NoError(t, err) } diff --git a/pkg/db/postgrest_request.go b/pkg/db/postgrest_request.go index bfc57739..a22c3a98 100644 --- a/pkg/db/postgrest_request.go +++ b/pkg/db/postgrest_request.go @@ -1,6 +1,7 @@ package db import ( + "encoding/json" "flag" "fmt" "strings" @@ -9,10 +10,10 @@ import ( "github.com/valyala/fasthttp" ) -func PostgrestRequest(ctx raiden.Context, method string, url string, payload []byte, headers map[string]string) ([]byte, *fasthttp.Response, error) { +func PostgrestRequestBind(ctx raiden.Context, method string, url string, payload []byte, headers map[string]string, bypass bool, result interface{}) (*fasthttp.Response, error) { if !isAllowedMethod(method) { - return nil, nil, fmt.Errorf("method %s is not allowed", method) + return nil, fmt.Errorf("method %s is not allowed", method) } client := &fasthttp.Client{} @@ -38,15 +39,20 @@ func PostgrestRequest(ctx raiden.Context, method string, url string, payload []b req.SetRequestURI(url) req.Header.SetMethod(method) - apikey := string(ctx.RequestContext().Request.Header.Peek("apikey")) - if apikey != "" { - req.Header.Set("apikey", apikey) - } - - bearerToken := string(ctx.RequestContext().Request.Header.Peek("Authorization")) - if bearerToken != "" && strings.HasPrefix(bearerToken, "Bearer ") { - bearerToken = strings.TrimSpace(strings.TrimPrefix(bearerToken, "Bearer ")) - req.Header.Set("Authorization", bearerToken) + if bypass { + if flag.Lookup("test.v") == nil { + req.Header.Set("apikey", getConfig().ServiceKey) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", getConfig().ServiceKey)) + } + } else { + apikey := string(ctx.RequestContext().Request.Header.Peek("apikey")) + if apikey != "" { + req.Header.Set("apikey", apikey) + } + bearerToken := string(ctx.RequestContext().Request.Header.Peek("Authorization")) + if bearerToken != "" && strings.HasPrefix(bearerToken, "Bearer ") { + req.Header.Set("Authorization", bearerToken) + } } for key, value := range headers { @@ -61,12 +67,18 @@ func PostgrestRequest(ctx raiden.Context, method string, url string, payload []b defer fasthttp.ReleaseResponse(res) if err := client.Do(req, res); err != nil { - return nil, res, err + return res, err } body := res.Body() - return body, res, nil + if result != nil { + if err := json.Unmarshal(body, result); err != nil { + return res, fmt.Errorf("failed to unmarshal response body: %w", err) + } + } + + return res, nil } func isAllowedMethod(method string) bool { diff --git a/pkg/db/update.go b/pkg/db/update.go index 89e57cd3..e0f05452 100644 --- a/pkg/db/update.go +++ b/pkg/db/update.go @@ -8,10 +8,10 @@ import ( "github.com/valyala/fasthttp" ) -func (q *Query) Update(p interface{}) ([]byte, error) { +func (q *Query) Update(p interface{}, model interface{}) error { jsonData, err := json.Marshal(p) if err != nil { - return nil, err + return err } url := q.GetUrl() @@ -29,10 +29,10 @@ func (q *Query) Update(p interface{}) ([]byte, error) { headers["Content-Type"] = "application/json" headers["Prefer"] = "return=representation" - body, _, err := PostgrestRequest(q.Context, fasthttp.MethodPatch, url, jsonData, headers) - if err != nil { - return nil, err + _, err0 := PostgrestRequestBind(q.Context, fasthttp.MethodPatch, url, jsonData, headers, q.ByPass, model) + if err0 != nil { + return err0 } - return body, nil + return nil } diff --git a/pkg/db/update_test.go b/pkg/db/update_test.go index 74ac9eed..e4ed2b70 100644 --- a/pkg/db/update_test.go +++ b/pkg/db/update_test.go @@ -14,10 +14,10 @@ func TestUpdate(t *testing.T) { CreatedAt: time.Now(), } - _, err := NewQuery(&mockRaidenContext). + err := NewQuery(&mockRaidenContext). Model(articleMockModel). Eq("id", 1). - Update(article) + Update(article, &articleMockModel) assert.NoError(t, err) } diff --git a/pkg/db/upsert.go b/pkg/db/upsert.go index 85107f88..d4e37f1b 100644 --- a/pkg/db/upsert.go +++ b/pkg/db/upsert.go @@ -15,10 +15,10 @@ const ( IgnoreDuplicates = "ignore-duplicates" ) -func (q *Query) Upsert(payload []interface{}, opt UpsertOptions) ([]byte, error) { +func (q *Query) Upsert(payload []interface{}, opt UpsertOptions) error { jsonData, err := json.Marshal(payload) if err != nil { - return nil, err + return err } url := q.GetUrl() @@ -27,10 +27,11 @@ func (q *Query) Upsert(payload []interface{}, opt UpsertOptions) ([]byte, error) headers["Content-Type"] = "application/json" headers["Prefer"] = "resolution=" + opt.OnConflict - resp, _, err := PostgrestRequest(q.Context, fasthttp.MethodPost, url, jsonData, headers) - if err != nil { - return nil, err + var a interface{} + _, err0 := PostgrestRequestBind(q.Context, fasthttp.MethodPost, url, jsonData, headers, q.ByPass, &a) + if err0 != nil { + return err0 } - return resp, nil + return nil } diff --git a/pkg/db/upsert_test.go b/pkg/db/upsert_test.go index 4f156a77..31dd54c0 100644 --- a/pkg/db/upsert_test.go +++ b/pkg/db/upsert_test.go @@ -36,7 +36,7 @@ func TestUpsert(t *testing.T) { OnConflict: MergeDuplicates, } - _, err := NewQuery(&mockRaidenContext). + err := NewQuery(&mockRaidenContext). Model(articleMockModel). Upsert(payload, opt)