Skip to content

Commit

Permalink
support graceful close for websocket client (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
linstohu authored Dec 26, 2023
1 parent 92aeeb3 commit 46184ac
Show file tree
Hide file tree
Showing 39 changed files with 807 additions and 575 deletions.
70 changes: 49 additions & 21 deletions binance/coinmfutures/websocketmarket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ type CoinMarginedMarketStreamClient struct {
// logger
logger *slog.Logger

ctx context.Context
stopCtx context.Context
cancel context.CancelFunc

conn *websocket.Conn
mu sync.RWMutex
isConnected bool
Expand All @@ -56,13 +58,14 @@ type CoinMarginedMarketStreamClient struct {
}

type CoinMarginedMarketStreamCfg struct {
BaseURL string `validate:"required"`
Debug bool
// Logger
Debug bool
BaseURL string `validate:"required"`
AutoReconnect bool `validate:"required"`

Logger *slog.Logger
}

func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg) (*CoinMarginedMarketStreamClient, error) {
func NewMarketStreamClient(cfg *CoinMarginedMarketStreamCfg) (*CoinMarginedMarketStreamClient, error) {
if err := validator.New().Struct(cfg); err != nil {
return nil, err
}
Expand All @@ -72,8 +75,7 @@ func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg
debug: cfg.Debug,
logger: cfg.Logger,

ctx: ctx,
autoReconnect: true,
autoReconnect: cfg.AutoReconnect,

subscriptions: cmap.New[struct{}](),
emitter: emission.NewEmitter(),
Expand All @@ -83,12 +85,33 @@ func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg
cli.logger = slog.Default()
}

err := cli.start()
return cli, nil
}

func (u *CoinMarginedMarketStreamClient) Open() error {
if u.stopCtx != nil {
return fmt.Errorf("%s: ws is already open", logPrefix)
}

u.stopCtx, u.cancel = context.WithCancel(context.Background())

err := u.start()
if err != nil {
return nil, err
return err
}

return cli, nil
return nil
}

func (u *CoinMarginedMarketStreamClient) Close() error {
if u.stopCtx == nil {
return fmt.Errorf("%s: ws is not open", logPrefix)
}

u.cancel()
u.stopCtx = nil

return nil
}

func (u *CoinMarginedMarketStreamClient) start() error {
Expand All @@ -99,7 +122,7 @@ func (u *CoinMarginedMarketStreamClient) start() error {
for i := 0; i < MaxTryTimes; i++ {
conn, _, err := u.connect()
if err != nil {
u.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error()))
u.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error()))
tm := (i + 1) * 5
time.Sleep(time.Duration(tm) * time.Second)
continue
Expand All @@ -111,6 +134,8 @@ func (u *CoinMarginedMarketStreamClient) start() error {
return errors.New("connect failed")
}

u.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, u.baseURL))

u.setIsConnected(true)

u.resubscribe()
Expand Down Expand Up @@ -141,15 +166,14 @@ func (u *CoinMarginedMarketStreamClient) reconnect() {

u.setIsConnected(false)

u.logger.Info("disconnect, then reconnect...")

time.Sleep(1 * time.Second)

select {
case <-u.ctx.Done():
u.logger.Info(fmt.Sprintf("never reconnect, %s", u.ctx.Err()))
case <-u.stopCtx.Done():
u.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix))
return
default:
u.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix))
u.start()
}
}
Expand Down Expand Up @@ -185,24 +209,28 @@ func (u *CoinMarginedMarketStreamClient) IsConnected() bool {
func (u *CoinMarginedMarketStreamClient) readMessages() {
for {
select {
case <-u.ctx.Done():
u.logger.Info(fmt.Sprintf("context done, error: %s", u.ctx.Err().Error()))
case <-u.stopCtx.Done():
u.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix))

if err := u.close(); err != nil {
u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error()))
u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error()))
return
}

u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix))
return
default:
var msg utils.AnyMessage
err := u.conn.ReadJSON(&msg)
if err != nil {
u.logger.Info(fmt.Sprintf("read object error, %s", err))
u.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err))

if err := u.close(); err != nil {
u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error()))
u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error()))
return
}

u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix))
return
}

Expand All @@ -212,7 +240,7 @@ func (u *CoinMarginedMarketStreamClient) readMessages() {
case msg.SubscribedMessage != nil:
err := u.handle(msg.SubscribedMessage)
if err != nil {
u.logger.Info(fmt.Sprintf("handle message error: %s", err.Error()))
u.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error()))
}
}
}
Expand Down
Loading

0 comments on commit 46184ac

Please sign in to comment.