From a6162fd6c19812cd44f27a446d53da3cf8829d2b Mon Sep 17 00:00:00 2001 From: Ehsan Noureddin Moosa Date: Fri, 24 May 2024 10:04:04 +0300 Subject: [PATCH] [fastws] fix reading very big messages --- kit/edge.go | 4 +- kit/edge_options.go | 9 ++++ kit/utils/generic.go | 6 +++ std/gateways/fastws/gateway.go | 75 ++++++++++++++++++++++------------ testenv/fastws_test.go | 33 +++++++++------ testenv/utils.go | 8 +++- 6 files changed, 96 insertions(+), 39 deletions(-) diff --git a/kit/edge.go b/kit/edge.go index 5045a6a1..31c6db32 100644 --- a/kit/edge.go +++ b/kit/edge.go @@ -38,6 +38,7 @@ type EdgeServer struct { // configs prefork bool + reusePort bool shutdownTimeout time.Duration // local store @@ -61,6 +62,7 @@ func NewServer(opts ...Option) *EdgeServer { s.l = cfg.logger s.prefork = cfg.prefork + s.reusePort = cfg.reusePort s.eh = cfg.errHandler s.gh = cfg.globalHandlers s.cd = cfg.connDelegate @@ -288,7 +290,7 @@ func (s *EdgeServer) startup(ctx context.Context) { err := s.nb[idx].gw.Start( ctx, GatewayStartConfig{ - ReusePort: s.prefork, + ReusePort: s.prefork || s.reusePort, }, ) if err != nil { diff --git a/kit/edge_options.go b/kit/edge_options.go index ca37e618..a7ed2a6a 100644 --- a/kit/edge_options.go +++ b/kit/edge_options.go @@ -5,6 +5,7 @@ import "time" type edgeConfig struct { logger Logger prefork bool + reusePort bool shutdownTimeout time.Duration gateways []Gateway cluster Cluster @@ -101,3 +102,11 @@ func WithConnDelegate(d ConnDelegate) Option { s.connDelegate = d } } + +// ReusePort asks Gateway to listen in REUSE_PORT mode. +// default this is false +func ReusePort(t bool) Option { + return func(s *edgeConfig) { + s.reusePort = t + } +} diff --git a/kit/utils/generic.go b/kit/utils/generic.go index bd7c3416..746f0e05 100644 --- a/kit/utils/generic.go +++ b/kit/utils/generic.go @@ -48,3 +48,9 @@ func OkOr[T any](v T, err error, fallback T) T { return v } + +func TryCast[T any](v any) T { + t, _ := v.(T) + + return t +} diff --git a/std/gateways/fastws/gateway.go b/std/gateways/fastws/gateway.go index 2a614009..3c979c89 100644 --- a/std/gateways/fastws/gateway.go +++ b/std/gateways/fastws/gateway.go @@ -13,7 +13,6 @@ import ( "github.com/clubpay/ronykit/kit/utils" "github.com/clubpay/ronykit/kit/utils/buf" "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" "github.com/panjf2000/gnet/v2" ) @@ -88,6 +87,8 @@ func (gw *gateway) OnClose(c gnet.Conn, _ error) (action gnet.Action) { func (gw *gateway) OnTraffic(c gnet.Conn) gnet.Action { wsc := gw.getConnWrap(c) if wsc == nil { + gw.b.l.Debugf("did not find ws conn for connID(%d)", utils.TryCast[int64](c.Context())) + return gnet.Close } @@ -96,6 +97,7 @@ func (gw *gateway) OnTraffic(c gnet.Conn) gnet.Action { _, err := sp.Upgrade(wsc.c) if err != nil { wsc.Close() + gw.b.l.Debugf("faild to upgrade websocket connID(%d): %v", utils.TryCast[int64](c.Context()), err) return gnet.Close } @@ -111,46 +113,69 @@ func (gw *gateway) OnTraffic(c gnet.Conn) gnet.Action { hdr ws.Header ) - hdr, err = wsc.r.NextFrame() - if err != nil { - if builtinErr.Is(err, io.EOF) { + for { + hdr, err = wsc.r.NextFrame() + if err != nil { + if builtinErr.Is(err, io.EOF) { + return gnet.None + } + gw.b.l.Debugf("failed to read next frame of connID(%d): %v", utils.TryCast[int64](c.Context()), err) + + return gnet.Close + } + + if hdr.OpCode.IsControl() { + if err = wsc.r.OnIntermediate(hdr, wsc.r); err != nil { + gw.b.l.Debugf( + "failed to handle control message of connID(%d), opCode(%d): %v", + utils.TryCast[int64](c.Context()), hdr.OpCode, err, + ) + + return gnet.Close + } + + if err = wsc.r.Discard(); err != nil { + gw.b.l.Debugf( + "failed to discard on control message connID(%d): %v", + utils.TryCast[int64](c.Context()), err, + ) + + return gnet.Close + } + return gnet.None } - return gnet.Close + if hdr.OpCode&(ws.OpText|ws.OpBinary) != hdr.OpCode { + if err = wsc.r.Discard(); err != nil { + return gnet.Close + } + + continue + } + + break } - var p []byte + var pBuff *buf.Bytes if hdr.Fin { // No more frames will be read. Use fixed sized buffer to read payload. - p = make([]byte, hdr.Length) // It is not possible to receive io.EOF here because Reader does not // return EOF if frame payload was successfully fetched. - _, err = io.ReadFull(wsc.r, p) + pBuff = buf.GetLen(int(hdr.Length)) + _, err = io.ReadFull(wsc.r, *pBuff.Bytes()) } else { - // Frame is fragmented, thus use io.ReadAll behavior. - var buff bytes.Buffer + // create a default buffer cap, since we don't know the exact size of payload + pBuff = buf.GetCap(8192) + buff := bytes.NewBuffer(*pBuff.Bytes()) _, err = buff.ReadFrom(wsc.r) - p = buff.Bytes() + pBuff.SetBytes(utils.ValPtr(buff.Bytes())) } if err != nil { return gnet.Close } - wsc.msgs = append(wsc.msgs, wsutil.Message{OpCode: hdr.OpCode, Payload: p}) - - for _, msg := range wsc.msgs { - if msg.OpCode&(ws.OpText|ws.OpBinary) == 0 { - continue - } - - payloadBuffer := buf.GetLen(len(msg.Payload)) - payloadBuffer.CopyFrom(msg.Payload) - - go gw.reactFunc(wsc, payloadBuffer, len(msg.Payload)) - } - - wsc.msgs = wsc.msgs[:0] + go gw.reactFunc(wsc, pBuff, pBuff.Len()) return gnet.None } diff --git a/testenv/fastws_test.go b/testenv/fastws_test.go index 7103b3a1..16e10514 100644 --- a/testenv/fastws_test.go +++ b/testenv/fastws_test.go @@ -37,6 +37,7 @@ func invokeEdgeServerWithFastWS(port int, desc ...kit.ServiceDescriptor) fx.Opti return fx.Invoke( func(lc fx.Lifecycle) { edge := kit.NewServer( + kit.ReusePort(true), kit.WithLogger(&stdLogger{}), kit.WithErrorHandler( func(ctx *kit.Context, err error) { @@ -47,6 +48,7 @@ func invokeEdgeServerWithFastWS(port int, desc ...kit.ServiceDescriptor) fx.Opti fastws.MustNew( fastws.WithPredicateKey("cmd"), fastws.Listen(fmt.Sprintf("tcp4://0.0.0.0:%d", port)), + fastws.WithLogger(&stdLogger{}), ), ), kit.WithServiceDesc(desc...), @@ -81,23 +83,30 @@ func fastwsWithHugePayload(t *testing.T, opt fx.Option) func(c C) { ), ) - time.Sleep(time.Second * 5) + time.Sleep(time.Second * 2) - wsCtx := stub.New("localhost:8082"). + wsCtx := stub.New( + "localhost:8082", + //stub.WithLogger(&stdLogger{}), + ). Websocket( stub.WithPredicateKey("cmd"), + stub.WithPingTime(time.Second), ) c.So(wsCtx.Connect(ctx, "/"), ShouldBeNil) - req := &services.EchoRequest{Input: utils.RandomID(12000)} - res := &services.EchoResponse{} - err := wsCtx.BinaryMessage( - ctx, "echo", req, res, - func(ctx context.Context, msg kit.Message, hdr stub.Header, err error) { - c.So(err, ShouldBeNil) - c.So(msg.(*services.EchoResponse).Output, ShouldEqual, req.Input) //nolint:forcetypeassert - }, - ) - c.So(err, ShouldBeNil) + for i := 0; i < 10; i++ { + req := &services.EchoRequest{Input: utils.RandomID(1024)} + res := &services.EchoResponse{} + err := wsCtx.BinaryMessage( + ctx, "echo", req, res, + func(ctx context.Context, msg kit.Message, hdr stub.Header, err error) { + c.So(err, ShouldBeNil) + c.So(msg.(*services.EchoResponse).Output, ShouldEqual, req.Input) //nolint:forcetypeassert + }, + ) + c.So(err, ShouldBeNil) + time.Sleep(time.Second) + } } } diff --git a/testenv/utils.go b/testenv/utils.go index 3e6eba54..f1f097ab 100644 --- a/testenv/utils.go +++ b/testenv/utils.go @@ -1,9 +1,15 @@ package testenv -import "fmt" +import ( + "fmt" + + "github.com/clubpay/ronykit/kit" +) type stdLogger struct{} +var _ kit.Logger = (*stdLogger)(nil) + func (s stdLogger) Debugf(format string, args ...any) { fmt.Printf("DEBUG: %s\n", fmt.Sprintf(format, args...)) }