Skip to content

Commit

Permalink
Better synchronization for websocket commands (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
ob-stripe authored Oct 22, 2019
1 parent 5cf3c20 commit 7c1de8c
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 76 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ export GO111MODULE := on
export GOBIN := $(shell pwd)/bin
export PATH := $(GOBIN):$(PATH)
export GOPROXY := https://gocenter.io
export GOLANGCI_LINT_VERSION := v1.21.0

# Install all the build and lint dependencies
setup:
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s $(GOLANGCI_LINT_VERSION)
curl -L https://git.io/misspell | sh
go mod download
.PHONY: setup
Expand Down
3 changes: 2 additions & 1 deletion pkg/login/client_login.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package login

import (
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -113,7 +114,7 @@ func getLinks(baseURL string, deviceName string) (*Links, error) {
data := url.Values{}
data.Set("device_name", deviceName)

res, err := client.PerformRequest(http.MethodPost, stripeCLIAuthPath, data.Encode(), nil)
res, err := client.PerformRequest(context.TODO(), http.MethodPost, stripeCLIAuthPath, data.Encode(), nil)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/login/login_message.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package login

import (
"context"
"encoding/json"
"fmt"
"net/url"
Expand Down Expand Up @@ -72,7 +73,7 @@ func getUserAccount(baseURL string, apiKey string) (*Account, error) {
APIKey: apiKey,
}

resp, err := client.PerformRequest("GET", "/v1/account", "", nil)
resp, err := client.PerformRequest(context.TODO(), "GET", "/v1/account", "", nil)

if err != nil {
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion pkg/login/poll.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package login

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -51,7 +52,7 @@ func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (*PollA

var count = 0
for count < maxAttempts {
res, err := client.PerformRequest(http.MethodGet, parsedURL.Path, parsedURL.Query().Encode(), nil)
res, err := client.PerformRequest(context.TODO(), http.MethodGet, parsedURL.Path, parsedURL.Query().Encode(), nil)
if err != nil {
return nil, nil, err
}
Expand Down
120 changes: 86 additions & 34 deletions pkg/logtailing/tailer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package logtailing

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -65,6 +66,8 @@ type Tailer struct {
webSocketClient *websocket.Client

interruptCh chan os.Signal

ctx context.Context
}

// EventPayload is the mapping for fields in event payloads from request log tailing
Expand Down Expand Up @@ -105,85 +108,134 @@ func New(cfg *Config) *Tailer {
}

// Run sets the websocket connection
func (tailer *Tailer) Run() error {
s := ansi.StartSpinner("Getting ready...", tailer.cfg.Log.Out)
func (t *Tailer) Run() error {
s := ansi.StartSpinner("Getting ready...", t.cfg.Log.Out)

// Intercept Ctrl+c so we can do some clean up
signal.Notify(tailer.interruptCh, os.Interrupt, syscall.SIGTERM)
// Create a context that will be canceled when Ctrl+C is pressed
ctx, cancel := context.WithCancel(context.Background())
t.ctx = ctx
signal.Notify(t.interruptCh, os.Interrupt, syscall.SIGTERM)

filters, err := jsonifyFilters(tailer.cfg.Filters)
if err != nil {
tailer.cfg.Log.Fatalf("Error while converting log filters to JSON encoding: %v", err)
}
go func() {
log.WithFields(log.Fields{
"prefix": "logtailing.Tailer.Run",
}).Debug("Ctrl+C received, cleaning up...")

session, err := tailer.stripeAuthClient.Authorize(tailer.cfg.DeviceName, tailer.cfg.WebSocketFeature, &filters)
<-t.interruptCh
cancel()
}()

// Create the CLI session
session, err := t.createSession()
if err != nil {
tailer.cfg.Log.Fatalf("Error while authenticating with Stripe: %v", err)
ansi.StopSpinner(s, "", t.cfg.Log.Out)
t.cfg.Log.Fatalf("Error while authenticating with Stripe: %v", err)
}

tailer.webSocketClient = websocket.NewClient(
// Create and start the websocket client
t.webSocketClient = websocket.NewClient(
session.WebSocketURL,
session.WebSocketID,
session.WebSocketAuthorizedFeature,
&websocket.Config{
EventHandler: websocket.EventHandlerFunc(tailer.processRequestLogEvent),
Log: tailer.cfg.Log,
NoWSS: tailer.cfg.NoWSS,
EventHandler: websocket.EventHandlerFunc(t.processRequestLogEvent),
Log: t.cfg.Log,
NoWSS: t.cfg.NoWSS,
ReconnectInterval: time.Duration(session.ReconnectDelay) * time.Second,
},
)
go tailer.webSocketClient.Run()

ansi.StopSpinner(s, "Ready! You're now waiting to receive API request logs (^C to quit)", tailer.cfg.Log.Out)
go t.webSocketClient.Run()

select {
case <-t.webSocketClient.Connected():
ansi.StopSpinner(s, "Ready! You're now waiting to receive API request logs (^C to quit)", t.cfg.Log.Out)
case <-t.ctx.Done():
ansi.StopSpinner(s, "", t.cfg.Log.Out)
t.cfg.Log.Fatalf("Aborting")
}

if session.DisplayConnectFilterWarning {
color := ansi.Color(os.Stdout)
fmt.Println(fmt.Sprintf("%s you specified the 'account' filter for connect accounts but are not a connect merchant, so the filter will not be applied.", color.Yellow("Warning")))
fmt.Println(fmt.Sprintf("%s you specified the 'account' filter for Connect accounts but are not a Connect user, so the filter will not be applied.", color.Yellow("Warning")))
}

// Block until Ctrl+C is received
<-tailer.interruptCh

log.WithFields(log.Fields{
"prefix": "logs.Tailer.Run",
}).Debug("Ctrl+C received, cleaning up...")
// Block until context is done (i.e. Ctrl+C is pressed)
<-t.ctx.Done()

if tailer.webSocketClient != nil {
tailer.webSocketClient.Stop()
if t.webSocketClient != nil {
t.webSocketClient.Stop()
}

log.WithFields(log.Fields{
"prefix": "logs.Tailer.Run",
"prefix": "logtailing.Tailer.Run",
}).Debug("Bye!")

return nil
}

func (tailer *Tailer) processRequestLogEvent(msg websocket.IncomingMessage) {
func (t *Tailer) createSession() (*stripeauth.StripeCLISession, error) {
var session *stripeauth.StripeCLISession

var err error

exitCh := make(chan struct{})

filters, err := jsonifyFilters(t.cfg.Filters)
if err != nil {
t.cfg.Log.Fatalf("Error while converting log filters to JSON encoding: %v", err)
}

go func() {
// Try to authorize at least 5 times before failing. Sometimes we have random
// transient errors that we just need to retry for.
for i := 0; i <= 5; i++ {
session, err = t.stripeAuthClient.Authorize(t.ctx, t.cfg.DeviceName, t.cfg.WebSocketFeature, &filters)

if err == nil {
exitCh <- struct{}{}
return
}

select {
case <-t.ctx.Done():
exitCh <- struct{}{}
return
case <-time.After(1 * time.Second):
}
}

exitCh <- struct{}{}
}()
<-exitCh

return session, err
}

func (t *Tailer) processRequestLogEvent(msg websocket.IncomingMessage) {
if msg.RequestLogEvent == nil {
tailer.cfg.Log.Debug("WebSocket specified for request logs received non-request-logs event")
t.cfg.Log.Debug("WebSocket specified for request logs received non-request-logs event")
return
}

requestLogEvent := msg.RequestLogEvent

tailer.cfg.Log.WithFields(log.Fields{
"prefix": "logs.Tailer.processRequestLogEvent",
t.cfg.Log.WithFields(log.Fields{
"prefix": "logtailing.Tailer.processRequestLogEvent",
"webhook_id": requestLogEvent.RequestLogID,
}).Debugf("Processing request log event")

var payload EventPayload
if err := json.Unmarshal([]byte(requestLogEvent.EventPayload), &payload); err != nil {
tailer.cfg.Log.Debug("Received malformed payload: ", err)
t.cfg.Log.Debug("Received malformed payload: ", err)
}

// Don't show stripecli/sessions logs since they're generated by the CLI
if payload.URL == "/v1/stripecli/sessions" {
tailer.cfg.Log.Debug("Filtering out /v1/stripecli/sessions from logs")
t.cfg.Log.Debug("Filtering out /v1/stripecli/sessions from logs")
return
}

if tailer.cfg.OutputFormat == outputFormatJSON {
if t.cfg.OutputFormat == outputFormatJSON {
fmt.Println(ansi.ColorizeJSON(requestLogEvent.EventPayload, false, os.Stdout))
return
}
Expand Down
81 changes: 59 additions & 22 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -84,37 +85,37 @@ type Proxy struct {
events map[string]bool

interruptCh chan os.Signal

ctx context.Context
}

// Run sets the websocket connection and starts the Goroutines to forward
// incoming events to the local endpoint.
func (p *Proxy) Run() error {
s := ansi.StartSpinner("Getting ready...", p.cfg.Log.Out)

// Intercept Ctrl+c so we can do some clean up
// Create a context that will be canceled when Ctrl+C is pressed
ctx, cancel := context.WithCancel(context.Background())
p.ctx = ctx
signal.Notify(p.interruptCh, os.Interrupt, syscall.SIGTERM)

var session *stripeauth.StripeCLISession

var err error

// Try to authorize at least 5 times before failing. Sometimes we have random
// transient errors that we just need to retry for.
for i := 0; i <= 5; i++ {
session, err = p.stripeAuthClient.Authorize(p.cfg.DeviceName, p.cfg.WebSocketFeature, nil)

if err == nil {
break
}
go func() {
log.WithFields(log.Fields{
"prefix": "proxy.Proxy.Run",
}).Debug("Ctrl+C received, cleaning up...")

time.Sleep(1 * time.Second)
}
<-p.interruptCh
cancel()
}()

// Create the CLI session
session, err := p.createSession()
if err != nil {
ansi.StopSpinner(s, "", p.cfg.Log.Out)
p.cfg.Log.Fatalf("Error while authenticating with Stripe: %v", err)
}

// Create and start the websocket client
p.webSocketClient = websocket.NewClient(
session.WebSocketURL,
session.WebSocketID,
Expand All @@ -124,18 +125,21 @@ func (p *Proxy) Run() error {
NoWSS: p.cfg.NoWSS,
ReconnectInterval: time.Duration(session.ReconnectDelay) * time.Second,
EventHandler: websocket.EventHandlerFunc(p.processWebhookEvent),
Ctx: p.ctx,
},
)
go p.webSocketClient.Run()

ansi.StopSpinner(s, fmt.Sprintf("Ready! Your webhook signing secret is %s (^C to quit)", ansi.Bold(session.Secret)), p.cfg.Log.Out)

// Block until Ctrl+C is received
<-p.interruptCh
select {
case <-p.webSocketClient.Connected():
ansi.StopSpinner(s, fmt.Sprintf("Ready! Your webhook signing secret is %s (^C to quit)", ansi.Bold(session.Secret)), p.cfg.Log.Out)
case <-p.ctx.Done():
ansi.StopSpinner(s, "", p.cfg.Log.Out)
p.cfg.Log.Fatalf("Aborting")
}

log.WithFields(log.Fields{
"prefix": "proxy.Proxy.Run",
}).Debug("Ctrl+C received, cleaning up...")
// Block until context is done (i.e. Ctrl+C is pressed)
<-p.ctx.Done()

if p.webSocketClient != nil {
p.webSocketClient.Stop()
Expand All @@ -148,6 +152,39 @@ func (p *Proxy) Run() error {
return nil
}

func (p *Proxy) createSession() (*stripeauth.StripeCLISession, error) {
var session *stripeauth.StripeCLISession

var err error

exitCh := make(chan struct{})

go func() {
// Try to authorize at least 5 times before failing. Sometimes we have random
// transient errors that we just need to retry for.
for i := 0; i <= 5; i++ {
session, err = p.stripeAuthClient.Authorize(p.ctx, p.cfg.DeviceName, p.cfg.WebSocketFeature, nil)

if err == nil {
exitCh <- struct{}{}
return
}

select {
case <-p.ctx.Done():
exitCh <- struct{}{}
return
case <-time.After(1 * time.Second):
}
}

exitCh <- struct{}{}
}()
<-exitCh

return session, err
}

func (p *Proxy) filterWebhookEvent(msg *websocket.WebhookEvent) bool {
if msg.Endpoint.APIVersion != nil && !p.cfg.UseLatestAPIVersion {
p.cfg.Log.WithFields(log.Fields{
Expand Down
3 changes: 2 additions & 1 deletion pkg/requests/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package requests

import (
"bufio"
"context"
"fmt"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -134,7 +135,7 @@ func (rb *Base) MakeRequest(apiKey, path string, params *RequestParameters, errO
rb.setVersionHeader(req, params)
}

resp, err := client.PerformRequest(rb.Method, path, data, configureReq)
resp, err := client.PerformRequest(context.TODO(), rb.Method, path, data, configureReq)
if err != nil {
return []byte{}, err
}
Expand Down
Loading

0 comments on commit 7c1de8c

Please sign in to comment.