diff --git a/clients/graphql/batch.go b/clients/graphql/batch.go index 078c64d..450a8b2 100644 --- a/clients/graphql/batch.go +++ b/clients/graphql/batch.go @@ -25,6 +25,10 @@ import ( func paginateBatch(inputs []*AlertsInput, response *graphql.Response) ([]*AlertsInput, []*protocol.AlertEvent, error) { // type-checking response + if response == nil { + return nil, nil, fmt.Errorf("nil graphql response") + } + batchAlertsResponseUnsafe, ok := response.Data.(*BatchGetAlertsResponse) if !ok { return nil, nil, fmt.Errorf("invalid pagination response") diff --git a/clients/graphql/client_test.go b/clients/graphql/client_test.go index 25633f0..d76a72f 100644 --- a/clients/graphql/client_test.go +++ b/clients/graphql/client_test.go @@ -1,32 +1,160 @@ package graphql import ( + "context" "fmt" + "net/http" + "net/http/httptest" "testing" + "time" + "github.com/forta-network/forta-core-go/protocol" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshal(t *testing.T) { - resp, data, err := parseResponse([]byte(testResponse)) - assert.NoError(t, err) + resp := parseBatchResponse([]byte(testResponse)) + data := (*resp.Data.(*BatchGetAlertsResponse))["alerts0"] + assert.NotNilf(t, resp, "graphql response can not be nil") - assert.NotNilf(t, data, "data can not be nil") for i := 0; i < 5; i++ { - assert.Equal(t, fmt.Sprintf("0x%d", i), data.Alerts.Alerts[i].Source.SourceAlert.Hash) - assert.Equal(t, "0xbbb", data.Alerts.Alerts[i].Source.SourceAlert.BotId) - assert.Equal(t, "2023-01-01T00:00:00Z", data.Alerts.Alerts[i].Source.SourceAlert.Timestamp) - assert.Equal(t, uint64(137), data.Alerts.Alerts[i].Source.SourceAlert.ChainId) - assert.Equal(t, "Block height: 17890044", data.Alerts.Alerts[i].Description) - assert.Equal(t, uint64(i), data.Alerts.Alerts[i].Source.Block.Number) + assert.Equal(t, fmt.Sprintf("0x%d", i), data.Alerts[i].Source.SourceAlert.Hash) + assert.Equal(t, "0xbbb", data.Alerts[i].Source.SourceAlert.BotId) + assert.Equal(t, "2023-01-01T00:00:00Z", data.Alerts[i].Source.SourceAlert.Timestamp) + assert.Equal(t, uint64(137), data.Alerts[i].Source.SourceAlert.ChainId) + assert.Equal(t, "Block height: 17890044", data.Alerts[i].Description) + assert.Equal(t, uint64(i), data.Alerts[i].Source.Block.Number) + } +} + +func TestGetAlertsBatch(t *testing.T) { + batchResp := parseBatchResponse([]byte(testResponse)) + responseItem := (*batchResp.Data.(*BatchGetAlertsResponse))["alerts0"] + expectedAlerts := responseItem.ToAlertEvents() + tests := []struct { + desc string + inputs []*AlertsInput + headers map[string]string + setupMock func(mux *http.ServeMux) + wantAlerts []*protocol.AlertEvent + wantErr bool + }{ + { + desc: "Successful Request", + inputs: []*AlertsInput{ + { + BlockSortDirection: "ASC", + CreatedSince: 30, + First: 3, + }, + }, + headers: map[string]string{ + "Authorization": "Bearer: token", + }, + setupMock: func(mux *http.ServeMux) { + // Here's a simple example of what your setup function might do: + mux.HandleFunc("/graphql", func(w http.ResponseWriter, r *http.Request) { + + fmt.Fprintf(w, testResponse) + }) + }, + wantAlerts: expectedAlerts, + wantErr: false, + }, + { + desc: "Failure due to server error", + headers: map[string]string{ + "Authorization": "Bearer: token", + }, + inputs: []*AlertsInput{ + { + Bots: []string{"0xabc"}, + }, + }, + setupMock: func(mux *http.ServeMux) { + // Here's a simple example of what your setup function might do: + mux.HandleFunc("/graphql", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "server error", http.StatusInternalServerError) + }) + }, + wantAlerts: nil, + wantErr: true, + }, + { + desc: "Failure due to unauthorized", + inputs: []*AlertsInput{ + { + Bots: []string{"0xabc"}, + }, + }, + headers: map[string]string{ + "Authorization": "", // No token + }, + setupMock: func(mux *http.ServeMux) { + // Here's a simple example of what your setup function might do: + mux.HandleFunc("/graphql", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + }) + }, + wantAlerts: nil, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + mux := http.NewServeMux() + ts := httptest.NewUnstartedServer(mux) + + tc.setupMock(mux) // Modify setupMock to accept *http.ServeMux + ts.Start() + defer ts.Close() + + // Prepare client + client := NewClient(fmt.Sprintf("%s/graphql", ts.URL)) + + // Get context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Invoke GetAlertsBatch + gotAlerts, gotErr := client.GetAlertsBatch(ctx, tc.inputs, tc.headers) + + if tc.wantErr { + require.Error(t, gotErr) + return + } + require.NoError(t, gotErr) + require.Equal(t, tc.wantAlerts, gotAlerts) + }) } } +func TestWS(t *testing.T) { + ctx := context.Background() + client := "http://localhost:8080/graphql" + // valid query + input1 := AlertsInput{ + BlockSortDirection: SortAsc, + First: 1, + } + // invalid query + input2 := AlertsInput{ + BlockSortDirection: SortAsc, + First: 0, + } + inputs := []*AlertsInput{&input1, &input2} + resp, err := fetchAlertsBatch(ctx, client, inputs, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + const testResponse = `{ "data": { - "alerts": { + "alerts0": { "pageInfo": { - "hasNextPage": true, + "hasNextPage": false, "endCursor": { "alertId": "0x0baefe6f0be064d7f3637af75a90964e7c231cb6c35266f51af2ce3539558b93", "blockNumber": 17890041 diff --git a/clients/graphql/models_test.go b/clients/graphql/models_test.go new file mode 100644 index 0000000..de6eb13 --- /dev/null +++ b/clients/graphql/models_test.go @@ -0,0 +1,98 @@ +package graphql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_createGetAlertsQuery(t *testing.T) { + testInput := AlertsInput{ + AlertId: "0xabc", + } + resp, variables := createGetAlertsQuery([]*AlertsInput{&testInput}) + assert.Equal(t, resp, mockExpectedQuery) + variable, ok := variables["input0"].(*AlertsInput) + assert.True(t, ok) + assert.Equal(t, variable.AlertId, testInput.AlertId) +} + +const mockExpectedQuery = `query getAlerts($input0: AlertsInput) {alerts0: alerts(input: $input0) { +pageInfo { + hasNextPage + endCursor { + alertId + blockNumber + } +} +alerts { + alertId + addresses + contracts { + name + projectId + } + createdAt + description + hash + metadata + name + projects { + id + } + protocol + scanNodeCount + severity + source { + transactionHash + bot { + chainIds + createdAt + description + developer + docReference + enabled + id + image + name + reference + repository + projects + scanNodes + version + } + block { + number + hash + timestamp + chainId + } + sourceAlert { + hash + botId + timestamp + chainId + } + } + alertDocumentType + findingType + relatedAlerts + chainId + labels { + label + confidence + entity + entityType + remove + metadata + uniqueKey + embedding + } + addressBloomFilter { + bitset + itemCount + k + m + } +} +}}`