Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support graceful close for websocket client #3

Merged
merged 1 commit into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions binance/coinmfutures/websocketmarket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ type CoinMarginedMarketStreamClient struct {
// logger
logger *slog.Logger

ctx context.Context
stopCtx context.Context
cancel context.CancelFunc

conn *websocket.Conn
mu sync.RWMutex
isConnected bool
Expand All @@ -56,13 +58,14 @@ type CoinMarginedMarketStreamClient struct {
}

type CoinMarginedMarketStreamCfg struct {
BaseURL string `validate:"required"`
Debug bool
// Logger
Debug bool
BaseURL string `validate:"required"`
AutoReconnect bool `validate:"required"`

Logger *slog.Logger
}

func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg) (*CoinMarginedMarketStreamClient, error) {
func NewMarketStreamClient(cfg *CoinMarginedMarketStreamCfg) (*CoinMarginedMarketStreamClient, error) {
if err := validator.New().Struct(cfg); err != nil {
return nil, err
}
Expand All @@ -72,8 +75,7 @@ func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg
debug: cfg.Debug,
logger: cfg.Logger,

ctx: ctx,
autoReconnect: true,
autoReconnect: cfg.AutoReconnect,

subscriptions: cmap.New[struct{}](),
emitter: emission.NewEmitter(),
Expand All @@ -83,12 +85,33 @@ func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg
cli.logger = slog.Default()
}

err := cli.start()
return cli, nil
}

func (u *CoinMarginedMarketStreamClient) Open() error {
if u.stopCtx != nil {
return fmt.Errorf("%s: ws is already open", logPrefix)
}

u.stopCtx, u.cancel = context.WithCancel(context.Background())

err := u.start()
if err != nil {
return nil, err
return err
}

return cli, nil
return nil
}

func (u *CoinMarginedMarketStreamClient) Close() error {
if u.stopCtx == nil {
return fmt.Errorf("%s: ws is not open", logPrefix)
}

u.cancel()
u.stopCtx = nil

return nil
}

func (u *CoinMarginedMarketStreamClient) start() error {
Expand All @@ -99,7 +122,7 @@ func (u *CoinMarginedMarketStreamClient) start() error {
for i := 0; i < MaxTryTimes; i++ {
conn, _, err := u.connect()
if err != nil {
u.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error()))
u.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error()))
tm := (i + 1) * 5
time.Sleep(time.Duration(tm) * time.Second)
continue
Expand All @@ -111,6 +134,8 @@ func (u *CoinMarginedMarketStreamClient) start() error {
return errors.New("connect failed")
}

u.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, u.baseURL))

u.setIsConnected(true)

u.resubscribe()
Expand Down Expand Up @@ -141,15 +166,14 @@ func (u *CoinMarginedMarketStreamClient) reconnect() {

u.setIsConnected(false)

u.logger.Info("disconnect, then reconnect...")

time.Sleep(1 * time.Second)

select {
case <-u.ctx.Done():
u.logger.Info(fmt.Sprintf("never reconnect, %s", u.ctx.Err()))
case <-u.stopCtx.Done():
u.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix))
return
default:
u.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix))
u.start()
}
}
Expand Down Expand Up @@ -185,24 +209,28 @@ func (u *CoinMarginedMarketStreamClient) IsConnected() bool {
func (u *CoinMarginedMarketStreamClient) readMessages() {
for {
select {
case <-u.ctx.Done():
u.logger.Info(fmt.Sprintf("context done, error: %s", u.ctx.Err().Error()))
case <-u.stopCtx.Done():
u.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix))

if err := u.close(); err != nil {
u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error()))
u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error()))
return
}

u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix))
return
default:
var msg utils.AnyMessage
err := u.conn.ReadJSON(&msg)
if err != nil {
u.logger.Info(fmt.Sprintf("read object error, %s", err))
u.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err))

if err := u.close(); err != nil {
u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error()))
u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error()))
return
}

u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix))
return
}

Expand All @@ -212,7 +240,7 @@ func (u *CoinMarginedMarketStreamClient) readMessages() {
case msg.SubscribedMessage != nil:
err := u.handle(msg.SubscribedMessage)
if err != nil {
u.logger.Info(fmt.Sprintf("handle message error: %s", err.Error()))
u.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error()))
}
}
}
Expand Down
Loading